{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "524620c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp losses.pytorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15392f6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd532cb1-d11d-468e-a0e5-eb1101ba6662",
   "metadata": {},
   "source": [
    "# PyTorch Losses\n",
    "\n",
    "> NeuralForecast contains a collection of PyTorch Loss classes aimed to be used during the models' optimization."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "096cfbec-1d59-454a-b572-5890103b2f1f",
   "metadata": {},
   "source": [
    "The most important training signal is the forecast error, which is the difference between the observed value $y_{\\tau}$ and the prediction $\\hat{y}_{\\tau}$ at time $\\tau$:\n",
    "\n",
    "$$e_{\\tau} = y_{\\tau}-\\hat{y}_{\\tau} \\qquad \\qquad \\tau \\in \\{t+1,\\dots,t+H \\}$$\n",
    "\n",
    "The train loss summarizes the forecast errors in different train optimization objectives.\n",
    "\n",
    "All the losses are `torch.nn.Module` subclasses, which allows PyTorch Lightning to automatically move them across CPU/GPU/TPU devices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acfa68dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from typing import Optional, Union, Tuple\n",
    "\n",
    "import math\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "import torch.nn.functional as F\n",
    "from torch.distributions import Distribution\n",
    "from torch.distributions import (\n",
    "    Bernoulli,\n",
    "    Normal, \n",
    "    StudentT, \n",
    "    Poisson,\n",
    "    NegativeBinomial\n",
    ")\n",
    "\n",
    "from torch.distributions import constraints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2508f7a9-1433-4ad8-8f2f-0078c6ed6c3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import matplotlib.pyplot as plt\n",
    "from fastcore.test import test_eq\n",
    "from nbdev.showdoc import show_doc\n",
    "from neuralforecast.utils import generate_series"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84e07e98-b4c8-4ade-b3b6-1d27f367aa0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "def _divide_no_nan(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Auxiliary function to divide `a` by `b` element-wise, replacing every\n",
    "    non-finite result (NaN from 0/0, +inf or -inf from x/0) with 0.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `a`: tensor, numerator.<br>\n",
    "    `b`: tensor, denominator, broadcastable with `a`.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `div`: tensor, `a / b` with NaN and +/-inf entries replaced by 0.\n",
    "    \"\"\"\n",
    "    div = a / b\n",
    "    # Both signs of infinity must be cleared: a negative numerator over 0 gives -inf.\n",
    "    return torch.nan_to_num(div, nan=0.0, posinf=0.0, neginf=0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "132db0ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "def _weighted_mean(losses, weights):\n",
    "    \"\"\"\n",
    "    Weighted average of per-datapoint losses.\n",
    "\n",
    "    Sums `losses * weights` over every element and divides by the total\n",
    "    weight; the division goes through `_divide_no_nan`, so a zero total\n",
    "    weight does not produce NaN.\n",
    "    \"\"\"\n",
    "    weighted_total = (losses * weights).sum()\n",
    "    total_weight = weights.sum()\n",
    "    return _divide_no_nan(weighted_total, total_weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f41562a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class BasePointLoss(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Base class for point loss functions.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `horizon_weight`: Tensor, array or sequence of size h, weight for each timestamp of the forecasting window. <br>\n",
    "    `outputsize_multiplier`: Multiplier for the output size. <br>\n",
    "    `output_names`: Names of the outputs. <br>\n",
    "    \"\"\"\n",
    "    def __init__(self, horizon_weight, outputsize_multiplier, output_names):\n",
    "        super(BasePointLoss, self).__init__()\n",
    "        if horizon_weight is not None:\n",
    "            # as_tensor also accepts plain Python sequences (lists/tuples).\n",
    "            horizon_weight = torch.as_tensor(\n",
    "                horizon_weight, dtype=torch.get_default_dtype()\n",
    "            ).flatten()\n",
    "        self.horizon_weight = horizon_weight\n",
    "        self.outputsize_multiplier = outputsize_multiplier\n",
    "        self.output_names = output_names\n",
    "        self.is_distribution_output = False\n",
    "\n",
    "    def domain_map(self, y_hat: torch.Tensor):\n",
    "        \"\"\"\n",
    "        Univariate loss operates in dimension [B,T,H]/[B,H]\n",
    "        This changes the network's output from [B,H,1]->[B,H]\n",
    "        \"\"\"\n",
    "        return y_hat.squeeze(-1)\n",
    "\n",
    "    def _compute_weights(self, y, mask):\n",
    "        \"\"\"\n",
    "        Compute final weights for each datapoint (based on all weights and all masks).\n",
    "        If horizon_weight is not set, every horizon step gets weight 1 for this call only.\n",
    "        If set, check that it has the same length as the horizon in x.\n",
    "\n",
    "        **Parameters:**<br>\n",
    "        `y`: tensor, Actual values; only used to build a mask of ones when `mask` is None.<br>\n",
    "        `mask`: tensor or None, Specifies datapoints to consider in loss.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `weights`: tensor shaped like `mask`, horizon weights multiplied by the mask.\n",
    "        \"\"\"\n",
    "        if mask is None:\n",
    "            mask = torch.ones_like(y)\n",
    "\n",
    "        if self.horizon_weight is None:\n",
    "            # Build the default locally instead of caching it on self: a cached\n",
    "            # default pinned the horizon to the first call's length and made later\n",
    "            # calls with another horizon fail the length assertion below.\n",
    "            horizon_weight = torch.ones(mask.shape[-1])\n",
    "        else:\n",
    "            assert mask.shape[-1] == len(self.horizon_weight), \\\n",
    "                'horizon_weight must have same length as Y'\n",
    "            horizon_weight = self.horizon_weight\n",
    "\n",
    "        weights = torch.ones_like(mask) * horizon_weight.to(mask.device)\n",
    "        return weights * mask"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b8a94d7d",
   "metadata": {},
   "source": [
    "# 1. Scale-dependent Errors\n",
    "\n",
    "These metrics are on the same scale as the data."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "82fc4679",
   "metadata": {},
   "source": [
    "## Mean Absolute Error (MAE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e413fae-c590-4713-aab9-37c61ed37dff",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class MAE(BasePointLoss):\n",
    "    r\"\"\"Mean Absolute Error\n",
    "\n",
    "    Calculates Mean Absolute Error between\n",
    "    `y` and `y_hat`. MAE measures the prediction\n",
    "    accuracy of a forecasting method by calculating the\n",
    "    absolute deviation between the prediction and the true\n",
    "    value at a given time and averages these deviations\n",
    "    over the length of the series.\n",
    "\n",
    "    $$ \\mathrm{MAE}(\\mathbf{y}_{\\tau}, \\mathbf{\\hat{y}}_{\\tau}) = \\frac{1}{H} \\sum^{t+H}_{\\tau=t+1} |y_{\\tau} - \\hat{y}_{\\tau}| $$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `horizon_weight`: Tensor of size h, weight for each timestamp of the forecasting window. <br>\n",
    "    \"\"\"\n",
    "    def __init__(self, horizon_weight=None):\n",
    "        super().__init__(horizon_weight=horizon_weight,\n",
    "                         outputsize_multiplier=1,\n",
    "                         output_names=[''])\n",
    "\n",
    "    def __call__(self,\n",
    "                 y: torch.Tensor,\n",
    "                 y_hat: torch.Tensor,\n",
    "                 mask: Union[torch.Tensor, None] = None):\n",
    "        \"\"\"\n",
    "        **Parameters:**<br>\n",
    "        `y`: tensor, Actual values.<br>\n",
    "        `y_hat`: tensor, Predicted values.<br>\n",
    "        `mask`: tensor, Specifies datapoints to consider in loss.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `mae`: tensor (single value), weighted mean of the absolute errors.\n",
    "        \"\"\"\n",
    "        abs_errors = (y - y_hat).abs()\n",
    "        weights = self._compute_weights(y=y, mask=mask)\n",
    "        return _weighted_mean(losses=abs_errors, weights=weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d004cd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(MAE, name='MAE.__init__', title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a20a273",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(MAE.__call__, name='MAE.__call__', title_level=3)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "0292c74d",
   "metadata": {},
   "source": [
    "![](imgs_losses/mae_loss.png)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "4f31cc3d",
   "metadata": {},
   "source": [
    "## Mean Squared Error (MSE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46cfe937",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class MSE(BasePointLoss):\n",
    "    r\"\"\"Mean Squared Error\n",
    "\n",
    "    Calculates Mean Squared Error between\n",
    "    `y` and `y_hat`. MSE measures the prediction\n",
    "    accuracy of a forecasting method by calculating the\n",
    "    squared deviation between the prediction and the true\n",
    "    value at a given time, and averages these deviations\n",
    "    over the length of the series.\n",
    "\n",
    "    $$ \\mathrm{MSE}(\\mathbf{y}_{\\tau}, \\mathbf{\\hat{y}}_{\\tau}) = \\frac{1}{H} \\sum^{t+H}_{\\tau=t+1} (y_{\\tau} - \\hat{y}_{\\tau})^{2} $$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `horizon_weight`: Tensor of size h, weight for each timestamp of the forecasting window. <br>\n",
    "    \"\"\"\n",
    "    def __init__(self, horizon_weight=None):\n",
    "        super().__init__(horizon_weight=horizon_weight,\n",
    "                         outputsize_multiplier=1,\n",
    "                         output_names=[''])\n",
    "\n",
    "    def __call__(self,\n",
    "                 y: torch.Tensor,\n",
    "                 y_hat: torch.Tensor,\n",
    "                 mask: Union[torch.Tensor, None] = None):\n",
    "        \"\"\"\n",
    "        **Parameters:**<br>\n",
    "        `y`: tensor, Actual values.<br>\n",
    "        `y_hat`: tensor, Predicted values.<br>\n",
    "        `mask`: tensor, Specifies datapoints to consider in loss.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `mse`: tensor (single value), weighted mean of the squared errors.\n",
    "        \"\"\"\n",
    "        squared_errors = (y - y_hat).pow(2)\n",
    "        weights = self._compute_weights(y=y, mask=mask)\n",
    "        return _weighted_mean(losses=squared_errors, weights=weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8c65b82",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(MSE, name='MSE.__init__', title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0126a7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(MSE.__call__, name='MSE.__call__', title_level=3)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7b23f9c1",
   "metadata": {},
   "source": [
    "![](imgs_losses/mse_loss.png)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b160140b",
   "metadata": {},
   "source": [
    "## Root Mean Squared Error (RMSE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "545ebfb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class RMSE(BasePointLoss):\n",
    "    r\"\"\"Root Mean Squared Error\n",
    "\n",
    "    Calculates Root Mean Squared Error between\n",
    "    `y` and `y_hat`. RMSE measures the prediction\n",
    "    accuracy of a forecasting method by calculating the squared deviation\n",
    "    between the prediction and the observed value at a given time,\n",
    "    averaging these deviations over the length of the series and taking\n",
    "    the square root of the result.\n",
    "    The RMSE is therefore on the same scale as the original time series,\n",
    "    so comparing it across series is only meaningful when they share a\n",
    "    common scale. RMSE has a direct connection to the L2 norm.\n",
    "\n",
    "    $$ \\mathrm{RMSE}(\\mathbf{y}_{\\tau}, \\mathbf{\\hat{y}}_{\\tau}) = \\sqrt{\\frac{1}{H} \\sum^{t+H}_{\\tau=t+1} (y_{\\tau} - \\hat{y}_{\\tau})^{2}} $$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `horizon_weight`: Tensor of size h, weight for each timestamp of the forecasting window. <br>\n",
    "    \"\"\"\n",
    "    def __init__(self, horizon_weight=None):\n",
    "        super().__init__(horizon_weight=horizon_weight,\n",
    "                         outputsize_multiplier=1,\n",
    "                         output_names=[''])\n",
    "\n",
    "    def __call__(self,\n",
    "                 y: torch.Tensor,\n",
    "                 y_hat: torch.Tensor,\n",
    "                 mask: Union[torch.Tensor, None] = None):\n",
    "        \"\"\"\n",
    "        **Parameters:**<br>\n",
    "        `y`: tensor, Actual values.<br>\n",
    "        `y_hat`: tensor, Predicted values.<br>\n",
    "        `mask`: tensor, Specifies datapoints to consider in loss.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `rmse`: tensor (single value), square root of the weighted mean squared error.\n",
    "        \"\"\"\n",
    "        squared_errors = (y - y_hat).pow(2)\n",
    "        weights = self._compute_weights(y=y, mask=mask)\n",
    "        weighted_mse = _weighted_mean(losses=squared_errors, weights=weights)\n",
    "        return weighted_mse.sqrt()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d961d383",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(RMSE, name='RMSE.__init__', title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d398d3e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(RMSE.__call__, name='RMSE.__call__', title_level=3)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "d4539e38",
   "metadata": {},
   "source": [
    "![](imgs_losses/rmse_loss.png)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8bcf5488",
   "metadata": {},
   "source": [
    "# 2. Percentage errors\n",
    "\n",
    "These metrics are unit-free, suitable for comparisons across series."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8eab97ec",
   "metadata": {},
   "source": [
    "## Mean Absolute Percentage Error (MAPE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adecb6bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class MAPE(BasePointLoss):\n",
    "    r\"\"\"Mean Absolute Percentage Error\n",
    "\n",
    "    Calculates Mean Absolute Percentage Error between\n",
    "    `y` and `y_hat`. MAPE measures the relative prediction\n",
    "    accuracy of a forecasting method by calculating the percentage deviation\n",
    "    between the prediction and the observed value at a given time and\n",
    "    averages these deviations over the length of the series.\n",
    "    The closer to zero an observed value is, the higher penalty MAPE loss\n",
    "    assigns to the corresponding error; observations equal to zero get a\n",
    "    scale of 0 and therefore contribute no loss.\n",
    "\n",
    "    $$ \\mathrm{MAPE}(\\mathbf{y}_{\\tau}, \\mathbf{\\hat{y}}_{\\tau}) = \\frac{1}{H} \\sum^{t+H}_{\\tau=t+1} \\frac{|y_{\\tau}-\\hat{y}_{\\tau}|}{|y_{\\tau}|} $$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `horizon_weight`: Tensor of size h, weight for each timestamp of the forecasting window. <br>\n",
    "\n",
    "    **References:**<br>\n",
    "    [Makridakis S., \"Accuracy measures: theoretical and practical concerns\".](https://www.sciencedirect.com/science/article/pii/0169207093900793)\n",
    "    \"\"\"\n",
    "    def __init__(self, horizon_weight=None):\n",
    "        super().__init__(horizon_weight=horizon_weight,\n",
    "                         outputsize_multiplier=1,\n",
    "                         output_names=[''])\n",
    "\n",
    "    def __call__(self,\n",
    "                 y: torch.Tensor,\n",
    "                 y_hat: torch.Tensor,\n",
    "                 mask: Union[torch.Tensor, None] = None):\n",
    "        \"\"\"\n",
    "        **Parameters:**<br>\n",
    "        `y`: tensor, Actual values.<br>\n",
    "        `y_hat`: tensor, Predicted values.<br>\n",
    "        `mask`: tensor, Specifies date stamps per series to consider in loss.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `mape`: tensor (single value).\n",
    "        \"\"\"\n",
    "        # 1/|y|, with zero observations mapped to 0 instead of inf.\n",
    "        inverse_scale = _divide_no_nan(torch.ones_like(y), y.abs())\n",
    "        percentage_errors = (y - y_hat).abs() * inverse_scale\n",
    "        weights = self._compute_weights(y=y, mask=mask)\n",
    "        return _weighted_mean(losses=percentage_errors, weights=weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "174e8042",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(MAPE, name='MAPE.__init__', title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da63f136",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(MAPE.__call__, name='MAPE.__call__', title_level=3)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c8ccdc69",
   "metadata": {},
   "source": [
    "![](imgs_losses/mape_loss.png)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "cb245891",
   "metadata": {},
   "source": [
    "## Symmetric MAPE (sMAPE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7566e649",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class SMAPE(BasePointLoss):\n",
    "    r\"\"\"Symmetric Mean Absolute Percentage Error\n",
    "\n",
    "    Calculates Symmetric Mean Absolute Percentage Error between\n",
    "    `y` and `y_hat`. SMAPE measures the relative prediction\n",
    "    accuracy of a forecasting method by calculating the relative deviation\n",
    "    of the prediction and the observed value scaled by the sum of the\n",
    "    absolute values for the prediction and observed value at a\n",
    "    given time, then averages these deviations over the length\n",
    "    of the series. This bounds the SMAPE between 0% and 200%,\n",
    "    which is desirable compared to the normal MAPE that\n",
    "    may be undetermined when the target is zero.\n",
    "\n",
    "    $$ \\mathrm{sMAPE}_{2}(\\mathbf{y}_{\\tau}, \\mathbf{\\hat{y}}_{\\tau}) = \\frac{2}{H} \\sum^{t+H}_{\\tau=t+1} \\frac{|y_{\\tau}-\\hat{y}_{\\tau}|}{|y_{\\tau}|+|\\hat{y}_{\\tau}|} $$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `horizon_weight`: Tensor of size h, weight for each timestamp of the forecasting window. <br>\n",
    "\n",
    "    **References:**<br>\n",
    "    [Makridakis S., \"Accuracy measures: theoretical and practical concerns\".](https://www.sciencedirect.com/science/article/pii/0169207093900793)\n",
    "    \"\"\"\n",
    "    def __init__(self, horizon_weight=None):\n",
    "        super().__init__(horizon_weight=horizon_weight,\n",
    "                         outputsize_multiplier=1,\n",
    "                         output_names=[''])\n",
    "\n",
    "    def __call__(self,\n",
    "                 y: torch.Tensor,\n",
    "                 y_hat: torch.Tensor,\n",
    "                 mask: Union[torch.Tensor, None] = None):\n",
    "        \"\"\"\n",
    "        **Parameters:**<br>\n",
    "        `y`: tensor, Actual values.<br>\n",
    "        `y_hat`: tensor, Predicted values.<br>\n",
    "        `mask`: tensor, Specifies date stamps per series to consider in loss.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `smape`: tensor (single value).\n",
    "        \"\"\"\n",
    "        abs_errors = (y - y_hat).abs()\n",
    "        denominator = y.abs() + y_hat.abs()\n",
    "        # Points where both y and y_hat are 0 give 0/0 and are mapped to 0.\n",
    "        ratios = _divide_no_nan(abs_errors, denominator)\n",
    "        weights = self._compute_weights(y=y, mask=mask)\n",
    "        # The factor 2 puts the result on the 0%-200% scale.\n",
    "        return 2 * _weighted_mean(losses=ratios, weights=weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dee99fb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(SMAPE, name='SMAPE.__init__', title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db62a845",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(SMAPE.__call__, name='SMAPE.__call__', title_level=3)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "bc3f2d6f",
   "metadata": {},
   "source": [
    "# 3. Scale-independent Errors\n",
    "\n",
    "These metrics measure the relative improvements versus baselines."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5b2dee1f",
   "metadata": {},
   "source": [
    "## Mean Absolute Scaled Error (MASE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cc34fae",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class MASE(BasePointLoss):\n",
    "    r\"\"\"Mean Absolute Scaled Error\n",
    "\n",
    "    Calculates the Mean Absolute Scaled Error between\n",
    "    `y` and `y_hat`. MASE measures the relative prediction\n",
    "    accuracy of a forecasting method by comparing the mean absolute errors\n",
    "    of the prediction and the observed value against the mean\n",
    "    absolute errors of the seasonal naive model.\n",
    "    The MASE is one of the components of the Overall Weighted Average (OWA)\n",
    "    used in the M4 Competition.\n",
    "\n",
    "    $$ \\mathrm{MASE}(\\mathbf{y}_{\\tau}, \\mathbf{\\hat{y}}_{\\tau}, \\mathbf{\\hat{y}}^{season}_{\\tau}) = \\frac{1}{H} \\sum^{t+H}_{\\tau=t+1} \\frac{|y_{\\tau}-\\hat{y}_{\\tau}|}{\\mathrm{MAE}(\\mathbf{y}_{\\tau}, \\mathbf{\\hat{y}}^{season}_{\\tau})} $$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `seasonality`: int, positive. Main frequency of the time series; Hourly 24,  Daily 7, Weekly 52, Monthly 12, Quarterly 4, Yearly 1.<br>\n",
    "    `horizon_weight`: Tensor of size h, weight for each timestamp of the forecasting window. <br>\n",
    "\n",
    "    **References:**<br>\n",
    "    [Rob J. Hyndman, & Koehler, A. B. \"Another look at measures of forecast accuracy\".](https://www.sciencedirect.com/science/article/pii/S0169207006000239)<br>\n",
    "    [Spyros Makridakis, Evangelos Spiliotis, Vassilios Assimakopoulos, \"The M4 Competition: 100,000 time series and 61 forecasting methods\".](https://www.sciencedirect.com/science/article/pii/S0169207019301128)\n",
    "    \"\"\"\n",
    "    def __init__(self, seasonality: int, horizon_weight=None):\n",
    "        # seasonality=0 would turn y_insample[:, :-0] into an empty slice.\n",
    "        if seasonality < 1:\n",
    "            raise ValueError(f'seasonality must be a positive integer, got {seasonality}')\n",
    "        super(MASE, self).__init__(horizon_weight=horizon_weight,\n",
    "                                   outputsize_multiplier=1,\n",
    "                                   output_names=[''])\n",
    "        self.seasonality = seasonality\n",
    "\n",
    "    def __call__(self,\n",
    "                 y: torch.Tensor,\n",
    "                 y_hat: torch.Tensor,\n",
    "                 y_insample: torch.Tensor,\n",
    "                 mask: Union[torch.Tensor, None] = None):\n",
    "        \"\"\"\n",
    "        **Parameters:**<br>\n",
    "        `y`: tensor (batch_size, output_size), Actual values.<br>\n",
    "        `y_hat`: tensor (batch_size, output_size), Predicted values.<br>\n",
    "        `y_insample`: tensor (batch_size, input_size), Actual insample values, used to compute the seasonal naive scale.<br>\n",
    "        `mask`: tensor, Specifies date stamps per series to consider in loss.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `mase`: tensor (single value).\n",
    "\n",
    "        **Raises:**<br>\n",
    "        `ValueError`: if `y_insample` has no more than `seasonality` timesteps.\n",
    "        \"\"\"\n",
    "        insample_size = y_insample.shape[1]\n",
    "        if insample_size <= self.seasonality:\n",
    "            # Otherwise the seasonal differences are empty, their mean is NaN and\n",
    "            # _divide_no_nan silently turns the whole loss into 0.\n",
    "            raise ValueError(\n",
    "                f'y_insample must have more than seasonality={self.seasonality} '\n",
    "                f'timesteps to compute the MASE scale, got {insample_size}')\n",
    "        delta_y = torch.abs(y - y_hat)\n",
    "        # In-sample MAE of the seasonal naive forecast, one scale per row of y_insample.\n",
    "        scale = torch.mean(torch.abs(y_insample[:, self.seasonality:] - \\\n",
    "                                     y_insample[:, :-self.seasonality]), axis=1)\n",
    "        losses = _divide_no_nan(delta_y, scale[:, None])\n",
    "        weights = self._compute_weights(y=y, mask=mask)\n",
    "        return _weighted_mean(losses=losses, weights=weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6a4cf21",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(MASE, name='MASE.__init__', title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32a2c11b",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(MASE.__call__, name='MASE.__call__', title_level=3)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "6e0c8fe5",
   "metadata": {},
   "source": [
    "![](imgs_losses/mase_loss.png)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "73bbdc4e",
   "metadata": {},
   "source": [
    "## Relative Mean Squared Error (relMSE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "954911d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class relMSE(BasePointLoss):\n",
    "    r\"\"\"Relative Mean Squared Error\n",
    "    Computes Relative Mean Squared Error (relMSE), as proposed by Hyndman & Koehler (2006)\n",
    "    as an alternative to percentage errors, to avoid measure instability.\n",
    "    $$ \\mathrm{relMSE}(\\mathbf{y}, \\mathbf{\\hat{y}}, \\mathbf{\\hat{y}}^{naive1}) =\n",
    "    \\frac{\\mathrm{MSE}(\\mathbf{y}, \\mathbf{\\hat{y}})}{\\mathrm{MSE}(\\mathbf{y}, \\mathbf{\\hat{y}}^{naive1})} $$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `y_train`: 2-D numpy array or tensor, Training values; the last column of each row is repeated over the horizon as the naive1 forecast.<br>\n",
    "    `horizon_weight`: Tensor of size h, weight for each timestamp of the forecasting window. <br>\n",
    "\n",
    "    **References:**<br>\n",
    "    - [Hyndman, R. J and Koehler, A. B. (2006).\n",
    "       \"Another look at measures of forecast accuracy\",\n",
    "       International Journal of Forecasting, Volume 22, Issue 4.](https://www.sciencedirect.com/science/article/pii/S0169207006000239)<br>\n",
    "    - [Kin G. Olivares, O. Nganba Meetei, Ruijun Ma, Rohan Reddy, Mengfei Cao, Lee Dicker. \n",
    "       \"Probabilistic Hierarchical Forecasting with Deep Poisson Mixtures\". \n",
    "       Submitted to the International Journal Forecasting, Working paper available at arxiv.](https://arxiv.org/pdf/2110.13179.pdf)\n",
    "    \"\"\"\n",
    "    def __init__(self, y_train, horizon_weight=None):\n",
    "        super(relMSE, self).__init__(horizon_weight=horizon_weight,\n",
    "                                     outputsize_multiplier=1,\n",
    "                                     output_names=[''])\n",
    "        self.y_train = y_train\n",
    "        self.mse = MSE(horizon_weight=horizon_weight)\n",
    "\n",
    "    def __call__(self,\n",
    "                 y: torch.Tensor,\n",
    "                 y_hat: torch.Tensor,\n",
    "                 mask: Union[torch.Tensor, None] = None):\n",
    "        \"\"\"\n",
    "        **Parameters:**<br>\n",
    "        `y`: tensor (batch_size, output_size), Actual values.<br>\n",
    "        `y_hat`: tensor (batch_size, output_size), Predicted values.<br>\n",
    "        `mask`: tensor, Specifies date stamps per series to consider in loss.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `relMSE`: tensor (single value).\n",
    "        \"\"\"\n",
    "        horizon = y.shape[-1]\n",
    "        # y_train may be a numpy array (as documented) or a tensor on another\n",
    "        # device; align it with y so the naive forecast can be compared to y.\n",
    "        y_train = torch.as_tensor(self.y_train, dtype=y.dtype, device=y.device)\n",
    "        # Naive1 benchmark: repeat the last training value over the whole horizon.\n",
    "        last_col = y_train[:, -1].unsqueeze(1)\n",
    "        y_naive = last_col.repeat(1, horizon)\n",
    "\n",
    "        norm = self.mse(y=y, y_hat=y_naive, mask=mask) # Already weighted\n",
    "        norm = norm + 1e-5 # Numerical stability\n",
    "        loss = self.mse(y=y, y_hat=y_hat, mask=mask) # Already weighted\n",
    "        loss = _divide_no_nan(loss, norm)\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edeb6f9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(relMSE, name='relMSE.__init__', title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a317b5c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(relMSE.__call__, name='relMSE.__call__', title_level=3)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c828438e",
   "metadata": {},
   "source": [
    "# 4. Probabilistic Errors\n",
    "\n",
    "These methods use statistical approaches for estimating unknown probability distributions using observed data. \n",
    "\n",
    "Maximum likelihood estimation (MLE) involves finding the parameter values that maximize the likelihood function, which measures the probability of obtaining the observed data given the parameter values. MLE has good theoretical properties and is efficient when its underlying assumptions are satisfied.\n",
    "\n",
    "In the non-parametric approach, quantile regression measures deviations asymmetrically, penalizing under- and over-estimation differently."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "999d8cb2",
   "metadata": {},
   "source": [
    "## Quantile Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd296fcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class QuantileLoss(BasePointLoss):\n",
    "    r\"\"\"Quantile Loss\n",
    "\n",
    "    Computes the quantile loss between `y` and `y_hat`.\n",
    "    QL measures the deviation of a quantile forecast.\n",
    "    By weighting the absolute deviation in a non-symmetric way, the\n",
    "    loss pays more attention to under or over estimation.\n",
    "    A common value for q is 0.5, which measures the deviation from the median\n",
    "    (the quantile loss is also known as the pinball loss).\n",
    "\n",
    "    $$ \\mathrm{QL}(\\mathbf{y}_{\\tau}, \\mathbf{\\hat{y}}^{(q)}_{\\tau}) = \\frac{1}{H} \\sum^{t+H}_{\\tau=t+1} \\Big( (1-q)\\,( \\hat{y}^{(q)}_{\\tau} - y_{\\tau} )_{+} + q\\,( y_{\\tau} - \\hat{y}^{(q)}_{\\tau} )_{+} \\Big) $$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `q`: float, between 0 and 1. The slope of the quantile loss, in the context of quantile regression, the q determines the conditional quantile level.<br>\n",
    "    `horizon_weight`: Tensor of size h, weight for each timestamp of the forecasting window. <br>\n",
    "\n",
    "    **References:**<br>\n",
    "    [Roger Koenker and Gilbert Bassett, Jr., \"Regression Quantiles\".](https://www.jstor.org/stable/1913643)\n",
    "    \"\"\"\n",
    "    def __init__(self, q, horizon_weight=None):\n",
    "        super().__init__(horizon_weight=horizon_weight,\n",
    "                         outputsize_multiplier=1,\n",
    "                         output_names=[f'_ql{q}'])\n",
    "        self.q = q\n",
    "\n",
    "    def __call__(self,\n",
    "                 y: torch.Tensor,\n",
    "                 y_hat: torch.Tensor,\n",
    "                 mask: Union[torch.Tensor, None] = None):\n",
    "        \"\"\"\n",
    "        **Parameters:**<br>\n",
    "        `y`: tensor, Actual values.<br>\n",
    "        `y_hat`: tensor, Predicted values.<br>\n",
    "        `mask`: tensor, Specifies datapoints to consider in loss.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `quantile_loss`: tensor (single value).\n",
    "        \"\"\"\n",
    "        residual = y - y_hat\n",
    "        # q * residual wins when under-predicting (residual > 0);\n",
    "        # (q - 1) * residual = (1 - q) * |residual| wins when over-predicting.\n",
    "        under = residual * self.q\n",
    "        over = residual * (self.q - 1)\n",
    "        losses = torch.maximum(under, over)\n",
    "        weights = self._compute_weights(y=y, mask=mask)\n",
    "        return _weighted_mean(losses=losses, weights=weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70bd46d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(QuantileLoss, title_level=3, name='QuantileLoss.__init__')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b1588e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(QuantileLoss.__call__, title_level=3, name='QuantileLoss.__call__')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "51ac874f",
   "metadata": {},
   "source": [
    "![](imgs_losses/q_loss.png)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "92dbb002",
   "metadata": {},
   "source": [
    "## Multi Quantile Loss (MQLoss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "291a0530",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "def level_to_outputs(level):\n",
    "    \"\"\"Map prediction-interval levels to quantiles and output names.\n",
    "\n",
    "    Each level `l` in [0, 100] contributes the quantile pair `50 - l/2` and\n",
    "    `50 + l/2`. The pairs are sorted ascending, the median (50) is prepended\n",
    "    and the result is rescaled to [0, 1].\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `level`: int list [0,100]. Probability levels for prediction intervals.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `quantiles`: torch.Tensor, the median followed by the sorted interval quantiles.<br>\n",
    "    `output_names`: str list, `'-median'` followed by the `'-lo-{l}'`/`'-hi-{l}'` names, aligned with `quantiles`.<br>\n",
    "    \"\"\"\n",
    "    # Flatten the (lo, hi) pairs with comprehensions; `sum(list_of_lists, [])`\n",
    "    # re-copies the accumulated list on every addition (quadratic).\n",
    "    qs = [q for l in level for q in (50 - l/2, 50 + l/2)]\n",
    "    output_names = [name for l in level for name in (f'-lo-{l}', f'-hi-{l}')]\n",
    "\n",
    "    sort_idx = np.argsort(qs)\n",
    "    quantiles = np.array(qs)[sort_idx]\n",
    "\n",
    "    # Add default median\n",
    "    quantiles = np.concatenate([np.array([50]), quantiles])\n",
    "    quantiles = torch.Tensor(quantiles) / 100\n",
    "    output_names = list(np.array(output_names)[sort_idx])\n",
    "    output_names.insert(0, '-median')\n",
    "\n",
    "    return quantiles, output_names\n",
    "\n",
    "def quantiles_to_outputs(quantiles):\n",
    "    \"\"\"Build output names for an explicit list of quantiles.\n",
    "\n",
    "    A quantile `q < 0.5` is named `'-lo-{level}'` and `q > 0.5` is named\n",
    "    `'-hi-{level}'`, where `level` is the equivalent symmetric interval level\n",
    "    (`100 - 200*q`, resp. `100 - 200*(1-q)`) rounded to 2 decimals;\n",
    "    `q == 0.5` is named `'-median'`.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `quantiles`: float list [0., 1.]. Quantiles to name.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `quantiles`: the input `quantiles`, unchanged.<br>\n",
    "    `output_names`: str list aligned with `quantiles`.<br>\n",
    "    \"\"\"\n",
    "    output_names = []\n",
    "    for q in quantiles:\n",
    "        if q<.50:\n",
    "            output_names.append(f'-lo-{np.round(100-200*q,2)}')\n",
    "        elif q>.50:\n",
    "            output_names.append(f'-hi-{np.round(100-200*(1-q),2)}')\n",
    "        else:\n",
    "            output_names.append('-median')\n",
    "    return quantiles, output_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21dc7968",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class MQLoss(BasePointLoss):\n",
    "    \"\"\"  Multi-Quantile loss\n",
    "\n",
    "    Calculates the Multi-Quantile loss (MQL) between `y` and `y_hat`.\n",
    "    MQL calculates the average multi-quantile Loss for\n",
    "    a given set of quantiles, based on the absolute \n",
    "    difference between predicted quantiles and observed values.\n",
    "    \n",
    "    $$ \\mathrm{MQL}(\\\\mathbf{y}_{\\\\tau},[\\\\mathbf{\\hat{y}}^{(q_{1})}_{\\\\tau}, ... ,\\hat{y}^{(q_{n})}_{\\\\tau}]) = \\\\frac{1}{n} \\\\sum_{q_{i}} \\mathrm{QL}(\\\\mathbf{y}_{\\\\tau}, \\\\mathbf{\\hat{y}}^{(q_{i})}_{\\\\tau}) $$\n",
    "    \n",
    "    The limit behavior of MQL allows measuring the accuracy \n",
    "    of a full predictive distribution $\\mathbf{\\hat{F}}_{\\\\tau}$ with \n",
    "    the continuous ranked probability score (CRPS). This can be achieved \n",
    "    through a numerical integration technique, that discretizes the quantiles \n",
    "    and treats the CRPS integral with a left Riemann approximation, averaging over \n",
    "    uniformly distanced quantiles.    \n",
    "    \n",
    "    $$ \\mathrm{CRPS}(y_{\\\\tau}, \\mathbf{\\hat{F}}_{\\\\tau}) = \\int^{1}_{0} \\mathrm{QL}(y_{\\\\tau}, \\hat{y}^{(q)}_{\\\\tau}) dq $$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `level`: int list [0,100]. Probability levels for prediction intervals; when `quantiles` is None the median is prepended. Defaults to [80, 90].<br>\n",
    "    `quantiles`: float list [0., 1.]. Alternative to level, quantiles to estimate from y distribution. If given, it overrides `level`.<br>\n",
    "    `horizon_weight`: Tensor of size h, weight for each timestamp of the forecasting window. <br>\n",
    "\n",
    "    **References:**<br>\n",
    "    [Roger Koenker and Gilbert Bassett, Jr., \"Regression Quantiles\".](https://www.jstor.org/stable/1913643)<br>\n",
    "    [James E. Matheson and Robert L. Winkler, \"Scoring Rules for Continuous Probability Distributions\".](https://www.jstor.org/stable/2629907)\n",
    "    \"\"\"\n",
    "    def __init__(self, level=[80, 90], quantiles=None, horizon_weight=None):\n",
    "\n",
    "        qs, output_names = level_to_outputs(level)\n",
    "        qs = torch.Tensor(qs)\n",
    "        # Transform quantiles to homogeneous output names\n",
    "        if quantiles is not None:\n",
    "            _, output_names = quantiles_to_outputs(quantiles)\n",
    "            qs = torch.Tensor(quantiles)\n",
    "\n",
    "        super(MQLoss, self).__init__(horizon_weight=horizon_weight,\n",
    "                                     outputsize_multiplier=len(qs),\n",
    "                                     output_names=output_names)\n",
    "        \n",
    "        # Registered as a non-trainable Parameter so it moves with the module across devices.\n",
    "        self.quantiles = torch.nn.Parameter(qs, requires_grad=False)\n",
    "\n",
    "    def domain_map(self, y_hat: torch.Tensor):\n",
    "        \"\"\"\n",
    "        Identity domain map [B,T,H,Q]/[B,H,Q]\n",
    "        \"\"\"\n",
    "        return y_hat\n",
    "    \n",
    "    def _compute_weights(self, y, mask):\n",
    "        \"\"\"\n",
    "        Compute final weights for each datapoint (based on all weights and all masks)\n",
    "        Set horizon_weight to a ones[H] tensor if not set.\n",
    "        If set, check that it has the same length as the horizon in x.\n",
    "\n",
    "        `y` is only used for its shape/device when `mask` is None. A provided\n",
    "        `mask` gets a singleton dimension inserted at position 1 so it\n",
    "        broadcasts over the quantile axis of the losses.\n",
    "        \"\"\"\n",
    "        if mask is None:\n",
    "            # ones_like already allocates on y's device.\n",
    "            mask = torch.ones_like(y)\n",
    "        else:\n",
    "            mask = mask.unsqueeze(1) # Add Q dimension.\n",
    "\n",
    "        if self.horizon_weight is None:\n",
    "            self.horizon_weight = torch.ones(mask.shape[-1])\n",
    "        else:\n",
    "            assert mask.shape[-1] == len(self.horizon_weight), \\\n",
    "                'horizon_weight must have same length as Y'\n",
    "    \n",
    "        weights = self.horizon_weight.clone()\n",
    "        weights = torch.ones_like(mask, device=mask.device) * weights.to(mask.device)\n",
    "        return weights * mask\n",
    "\n",
    "    def __call__(self,\n",
    "                 y: torch.Tensor,\n",
    "                 y_hat: torch.Tensor,\n",
    "                 mask: Union[torch.Tensor, None] = None):\n",
    "        \"\"\"\n",
    "        **Parameters:**<br>\n",
    "        `y`: tensor, Actual values.<br>\n",
    "        `y_hat`: tensor, Predicted values, with the quantile axis last.<br>\n",
    "        `mask`: tensor, Specifies date stamps per serie to consider in loss.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `mqloss`: tensor (single value).\n",
    "        \"\"\"\n",
    "        # Broadcast y against the trailing quantile axis of y_hat.\n",
    "        error  = y_hat - y.unsqueeze(-1)\n",
    "        sq     = torch.maximum(-error, torch.zeros_like(error))  # under-prediction, weighted by q\n",
    "        s1_q   = torch.maximum(error, torch.zeros_like(error))   # over-prediction, weighted by (1-q)\n",
    "        losses = (1/len(self.quantiles))*(self.quantiles * sq + (1 - self.quantiles) * s1_q)\n",
    "\n",
    "        if y_hat.ndim == 3: # BaseWindows\n",
    "            losses = losses.swapaxes(-2,-1) # [B,H,Q] -> [B,Q,H] (needed for horizon weighting, H at the end)\n",
    "        elif y_hat.ndim == 4: # BaseRecurrent\n",
    "            losses = losses.swapaxes(-2,-1)\n",
    "            losses = losses.swapaxes(-2,-3) # [B,seq_len,H,Q] -> [B,Q,seq_len,H] (needed for horizon weighting, H at the end)\n",
    "\n",
    "        weights = self._compute_weights(y=losses, mask=mask) # Use losses for extra dim\n",
    "        # NOTE: a provided mask yields weights with a singleton Q dimension, while\n",
    "        # mask=None yields weights with the full shape of `losses` (including Q).\n",
    "\n",
    "        return _weighted_mean(losses=losses, weights=weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f42ec82",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(MQLoss, title_level=3, name='MQLoss.__init__')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bac2237a",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(MQLoss.__call__, title_level=3, name='MQLoss.__call__')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "33b66b0e",
   "metadata": {},
   "source": [
    "![](imgs_losses/mq_loss.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da37f2ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | hide\n",
    "# Unit tests to check MQLoss' stored quantiles\n",
    "# attribute is correctly instantiated\n",
    "check = MQLoss(level=[80, 90])\n",
    "test_eq(len(check.quantiles), 5)\n",
    "test_eq(check.output_names, ['-median', '-lo-90', '-lo-80', '-hi-80', '-hi-90'])\n",
    "\n",
    "# Explicit quantiles are named after their equivalent symmetric interval level\n",
    "check = MQLoss(quantiles=[0.0100, 0.1000, 0.5, 0.9000, 0.9900])\n",
    "test_eq(check.output_names, ['-lo-98.0', '-lo-80.0', '-median', '-hi-80.0', '-hi-98.0'])\n",
    "test_eq(len(check.quantiles), 5)\n",
    "\n",
    "check = MQLoss(quantiles=[0.0100, 0.1000, 0.9000, 0.9900])\n",
    "test_eq(len(check.quantiles), 4)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "895ec0c0",
   "metadata": {},
   "source": [
    "## DistributionLoss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "801785b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "def weighted_average(x: torch.Tensor, \n",
    "                     weights: Optional[torch.Tensor]=None, dim=None) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Computes the weighted average of a given tensor across a given dim, masking\n",
    "    values associated with weight zero,\n",
    "    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `x`: Input tensor, of which the average must be computed.<br>\n",
    "    `weights`: Weights tensor, of the same shape as `x`.<br>\n",
    "    `dim`: The dim along which to average `x`. If None, average over all elements.<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `Tensor`: The tensor with values averaged along the specified `dim`.<br>\n",
    "    \"\"\"\n",
    "    if weights is not None:\n",
    "        weighted_tensor = torch.where(\n",
    "            weights != 0, x * weights, torch.zeros_like(x)\n",
    "        )\n",
    "        # Compare against None explicitly: `dim=0` is a valid axis but falsy,\n",
    "        # and a truthiness test would silently reduce over all elements instead.\n",
    "        if dim is None:\n",
    "            weighted_sum, weights_sum = weighted_tensor.sum(), weights.sum()\n",
    "        else:\n",
    "            weighted_sum, weights_sum = weighted_tensor.sum(dim=dim), weights.sum(dim=dim)\n",
    "        # Denominator clamped to at least 1: fully masked slices yield 0 instead of nan.\n",
    "        sum_weights = torch.clamp(weights_sum, min=1.0)\n",
    "        return weighted_sum / sum_weights\n",
    "    else:\n",
    "        return x.mean() if dim is None else x.mean(dim=dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83b90c8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "def bernoulli_domain_map(input: torch.Tensor):\n",
    "    \"\"\" Bernoulli Domain Map\n",
    "    Maps input into distribution constraints, by construction input's \n",
    "    last dimension is of matching `distr_args` length.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `input`: tensor, of dimensions [B,T,H,theta] or [B,H,theta].<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `(probs,)`: tuple with tensors of Bernoulli distribution arguments.<br>\n",
    "    \"\"\"\n",
    "    return (input.squeeze(-1),)\n",
    "\n",
    "def bernoulli_scale_decouple(output, loc=None, scale=None):\n",
    "    \"\"\" Bernoulli Scale Decouple\n",
    "\n",
    "    Adds Bernoulli domain protection to the distribution parameters by\n",
    "    mapping the raw output to probabilities with a sigmoid.\n",
    "    `loc` and `scale` are accepted for interface compatibility with the\n",
    "    other scale decouples but are not used.\n",
    "    \"\"\"\n",
    "    probs = torch.sigmoid(output[0])\n",
    "    return (probs,)\n",
    "\n",
    "def student_domain_map(input: torch.Tensor):\n",
    "    \"\"\" Student T Domain Map\n",
    "    Maps input into distribution constraints, by construction input's \n",
    "    last dimension is of matching `distr_args` length.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `input`: tensor, of dimensions [B,T,H,theta] or [B,H,theta].<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `(df, loc, scale)`: tuple with tensors of StudentT distribution arguments.<br>\n",
    "    \"\"\"\n",
    "    df, loc, scale = torch.tensor_split(input, 3, dim=-1)\n",
    "    return df.squeeze(-1), loc.squeeze(-1), scale.squeeze(-1)\n",
    "\n",
    "def student_scale_decouple(output, loc=None, scale=None, eps: float=0.1):\n",
    "    \"\"\" StudentT Scale Decouple\n",
    "\n",
    "    Stabilizes model's output optimization, by learning residual\n",
    "    variance and residual location based on anchoring `loc`, `scale`.\n",
    "    Also adds StudentT domain protection to the distribution parameters:\n",
    "    the scale goes through a softplus and the degrees of freedom are kept\n",
    "    above 2 (`df = 2 + softplus(df)`).\n",
    "\n",
    "    `eps`: float, added to the softplus scale before multiplying by `scale`;\n",
    "    only applied when both `loc` and `scale` are given.<br>\n",
    "    \"\"\"\n",
    "    df, mean, tscale = output\n",
    "    tscale = F.softplus(tscale)\n",
    "    if (loc is not None) and (scale is not None):\n",
    "        mean = (mean * scale) + loc\n",
    "        tscale = (tscale + eps) * scale\n",
    "    df = 2.0 + F.softplus(df)\n",
    "    return (df, mean, tscale)\n",
    "\n",
    "def normal_domain_map(input: torch.Tensor):\n",
    "    \"\"\" Normal Domain Map\n",
    "    Maps input into distribution constraints, by construction input's \n",
    "    last dimension is of matching `distr_args` length.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `input`: tensor, of dimensions [B,T,H,theta] or [B,H,theta].<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `(mean, std)`: tuple with tensors of Normal distribution arguments.<br>\n",
    "    \"\"\"\n",
    "    mean, std = torch.tensor_split(input, 2, dim=-1)\n",
    "    return mean.squeeze(-1), std.squeeze(-1)\n",
    "\n",
    "def normal_scale_decouple(output, loc=None, scale=None, eps: float=0.2):\n",
    "    \"\"\" Normal Scale Decouple\n",
    "\n",
    "    Stabilizes model's output optimization, by learning residual\n",
    "    variance and residual location based on anchoring `loc`, `scale`.\n",
    "    Also adds Normal domain protection to the distribution parameters\n",
    "    (the standard deviation goes through a softplus).\n",
    "\n",
    "    `eps`: float, added to the softplus std before multiplying by `scale`;\n",
    "    only applied when both `loc` and `scale` are given.<br>\n",
    "    \"\"\"\n",
    "    mean, std = output\n",
    "    std = F.softplus(std)\n",
    "    if (loc is not None) and (scale is not None):\n",
    "        mean = (mean * scale) + loc\n",
    "        std = (std + eps) * scale\n",
    "    return (mean, std)\n",
    "\n",
    "def poisson_domain_map(input: torch.Tensor):\n",
    "    \"\"\" Poisson Domain Map\n",
    "    Maps input into distribution constraints, by construction input's \n",
    "    last dimension is of matching `distr_args` length.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `input`: tensor, of dimensions [B,T,H,theta] or [B,H,theta].<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `(rate,)`: tuple with tensors of Poisson distribution arguments.<br>\n",
    "    \"\"\"\n",
    "    return (input.squeeze(-1),)\n",
    "\n",
    "def poisson_scale_decouple(output, loc=None, scale=None):\n",
    "    \"\"\" Poisson Scale Decouple\n",
    "\n",
    "    Stabilizes model's output optimization, by learning a residual\n",
    "    rate based on anchoring `loc`, `scale`.\n",
    "    Also adds Poisson domain protection to the distribution parameters\n",
    "    (softplus plus a small epsilon keeps the rate strictly positive).\n",
    "    \"\"\"\n",
    "    eps  = 1e-10\n",
    "    rate = output[0]\n",
    "    if (loc is not None) and (scale is not None):\n",
    "        rate = (rate * scale) + loc\n",
    "    rate = F.softplus(rate) + eps\n",
    "    return (rate,)\n",
    "\n",
    "def nbinomial_domain_map(input: torch.Tensor):\n",
    "    \"\"\" Negative Binomial Domain Map\n",
    "    Maps input into distribution constraints, by construction input's \n",
    "    last dimension is of matching `distr_args` length.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `input`: tensor, of dimensions [B,T,H,theta] or [B,H,theta].<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `(mu, alpha)`: tuple with the raw mean and dispersion tensors consumed by `nbinomial_scale_decouple`.<br>\n",
    "    \"\"\"\n",
    "    mu, alpha = torch.tensor_split(input, 2, dim=-1)\n",
    "    return mu.squeeze(-1), alpha.squeeze(-1)\n",
    "\n",
    "def nbinomial_scale_decouple(output, loc=None, scale=None):\n",
    "    \"\"\" Negative Binomial Scale Decouple\n",
    "\n",
    "    Converts the (mean, dispersion) parametrization into the\n",
    "    (total_count, probs) arguments of `torch.distributions.NegativeBinomial`.\n",
    "    Both raw values go through a softplus for domain protection. When `loc`\n",
    "    and `scale` are given, the mean is multiplied by `loc` and the dispersion\n",
    "    divided by `loc + 1`; `scale` itself is not used.\n",
    "    \"\"\"\n",
    "    mu, alpha = output\n",
    "    mu = F.softplus(mu) + 1e-8\n",
    "    alpha = F.softplus(alpha) + 1e-8    # alpha = 1/total_counts\n",
    "    if (loc is not None) and (scale is not None):\n",
    "        # Out-of-place updates, consistent with the other scale decouples.\n",
    "        mu = mu * loc\n",
    "        alpha = alpha / (loc + 1.)\n",
    "\n",
    "    # mu = total_count * (probs/(1-probs))\n",
    "    # => probs = mu / (total_count + mu)\n",
    "    # => probs = mu / [total_count * (1 + mu * (1/total_count))]\n",
    "    total_count = 1.0 / alpha\n",
    "    probs = (mu * alpha / (1.0 + mu * alpha)) + 1e-8\n",
    "    return (total_count, probs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03294edd",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "def est_lambda(mu, rho):\n",
    "    \"\"\"Poisson rate of the compound Poisson-Gamma form of a Tweedie (unit dispersion).\"\"\"\n",
    "    return mu ** (2 - rho) / (2 - rho)\n",
    "\n",
    "def est_alpha(rho):\n",
    "    \"\"\"Per-event Gamma shape of the compound Poisson-Gamma form of a Tweedie.\"\"\"\n",
    "    return (2 - rho) / (rho - 1)\n",
    "\n",
    "def est_beta(mu, rho):\n",
    "    \"\"\"Gamma rate of the compound Poisson-Gamma form of a Tweedie (unit dispersion).\"\"\"\n",
    "    return mu ** (1 - rho) / (rho - 1)\n",
    "\n",
    "\n",
    "class Tweedie(Distribution):\n",
    "    \"\"\" Tweedie Distribution\n",
    "\n",
    "    The Tweedie distribution is a compound probability, special case of exponential\n",
    "    dispersion models EDMs defined by its mean-variance relationship.\n",
    "    The distribution is particularly useful to model sparse series as the probability has\n",
    "    positive mass at zero but otherwise is continuous.\n",
    "\n",
    "    $Y \\sim \\mathrm{ED}(\\\\mu,\\\\sigma^{2}) \\qquad\n",
    "    \\mathbb{P}(y|\\\\mu ,\\\\sigma^{2})=h(\\\\sigma^{2},y) \\\\exp \\\\left({\\\\frac {\\\\theta y-A(\\\\theta )}{\\\\sigma^{2}}}\\\\right)$<br>\n",
    "    \n",
    "    $\\mu =A'(\\\\theta ) \\qquad \\mathrm{Var}(Y) = \\\\sigma^{2} \\\\mu^{\\\\rho}$\n",
    "    \n",
    "    Cases of the variance relationship include Normal (`rho` = 0), Poisson (`rho` = 1),\n",
    "    Gamma (`rho` = 2), inverse Gaussian (`rho` = 3).\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `log_mu`: tensor, with log of means.<br>\n",
    "    `rho`: float, Tweedie variance power (1,2). Fixed across all observations.<br>\n",
    "    `sigma2`: Tweedie dispersion. Not a constructor argument yet; currently fixed to 1.<br>\n",
    "\n",
    "    **References:**<br>\n",
    "    - [Tweedie, M. C. K. (1984). An index which distinguishes between some important exponential families. Statistics: Applications and New Directions. \n",
    "    Proceedings of the Indian Statistical Institute Golden Jubilee International Conference (Eds. J. K. Ghosh and J. Roy), pp. 579-604. Calcutta: Indian Statistical Institute.]()<br>\n",
    "    - [Jorgensen, B. (1987). Exponential Dispersion Models. Journal of the Royal Statistical Society. \n",
    "       Series B (Methodological), 49(2), 127–162. http://www.jstor.org/stable/2345415](http://www.jstor.org/stable/2345415)<br>\n",
    "    \"\"\"\n",
    "    def __init__(self, log_mu, rho, validate_args=None):\n",
    "        # TODO: add sigma2 dispersion\n",
    "        # TODO add constraints\n",
    "        # arg_constraints = {'log_mu': constraints.real, 'rho': constraints.positive}\n",
    "        # support = constraints.real\n",
    "        self.log_mu = log_mu\n",
    "        self.rho = rho\n",
    "        assert rho>1 and rho<2, f'rho={rho} parameter needs to be between (1,2).'\n",
    "\n",
    "        batch_shape = log_mu.size()\n",
    "        super(Tweedie, self).__init__(batch_shape, validate_args=validate_args)\n",
    "\n",
    "    @property\n",
    "    def mean(self):\n",
    "        return torch.exp(self.log_mu)\n",
    "\n",
    "    @property\n",
    "    def variance(self):\n",
    "        # Var(Y) = sigma2 * mu**rho, with sigma2 fixed to 1 (see TODO in __init__).\n",
    "        return self.mean ** self.rho\n",
    "\n",
    "    def sample(self, sample_shape=torch.Size()):\n",
    "        \"\"\"Draw samples via the compound Poisson-Gamma representation:\n",
    "        N ~ Poisson(lambda) events, each contributing a Gamma(alpha, beta)\n",
    "        amount, so the sample is exactly 0 whenever N == 0.\n",
    "        \"\"\"\n",
    "        shape = self._extended_shape(sample_shape)\n",
    "        with torch.no_grad():\n",
    "            mu   = self.mean\n",
    "            rho  = self.rho * torch.ones_like(mu)\n",
    "            sigma2 = 1 #TODO\n",
    "\n",
    "            rate  = est_lambda(mu, rho) / sigma2  # rate for poisson\n",
    "            alpha = est_alpha(rho)                # alpha for Gamma distribution\n",
    "            beta  = est_beta(mu, rho) / sigma2    # beta for Gamma distribution\n",
    "            \n",
    "            # Expand for sample\n",
    "            rate = rate.expand(shape)\n",
    "            alpha = alpha.expand(shape)\n",
    "            beta = beta.expand(shape)\n",
    "\n",
    "            N = torch.poisson(rate)\n",
    "            # Gamma requires a strictly positive concentration (validated by\n",
    "            # default), so N*alpha == 0 would raise. Use a placeholder of 1\n",
    "            # where N == 0; those draws are overwritten with 0 below.\n",
    "            concentration = torch.where(N > 0, N * alpha, torch.ones_like(N))\n",
    "            gamma = torch.distributions.gamma.Gamma(concentration, beta)\n",
    "            samples = gamma.sample()\n",
    "            samples[N==0] = 0\n",
    "\n",
    "            return samples\n",
    "\n",
    "    def log_prob(self, y_true):\n",
    "        \"\"\"Unnormalized Tweedie log-likelihood with unit dispersion:\n",
    "        `y * mu**(1-rho)/(1-rho) - mu**(2-rho)/(2-rho)`. The normalizing term\n",
    "        `log h(sigma2, y)` is omitted because it does not depend on `log_mu`.\n",
    "        \"\"\"\n",
    "        rho = self.rho\n",
    "        y_pred = self.log_mu\n",
    "\n",
    "        a = y_true * torch.exp((1 - rho) * y_pred) / (1 - rho)\n",
    "        b = torch.exp((2 - rho) * y_pred) / (2 - rho)\n",
    "\n",
    "        return a - b\n",
    "\n",
    "def tweedie_domain_map(input: torch.Tensor):\n",
    "    \"\"\" Tweedie Domain Map\n",
    "    Maps input into distribution constraints, by construction input's \n",
    "    last dimension is of matching `distr_args` length.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `input`: tensor, of dimensions [B,T,H,theta] or [B,H,theta].<br>\n",
    "\n",
    "    **Returns:**<br>\n",
    "    `(log_mu,)`: tuple with tensors of Tweedie distribution arguments.<br>\n",
    "    \"\"\"\n",
    "    return (input.squeeze(-1),)\n",
    "\n",
    "def tweedie_scale_decouple(output, loc=None, scale=None):\n",
    "    \"\"\" Tweedie Scale Decouple\n",
    "\n",
    "    Anchors the predicted log-mean on `loc` by adding `log(loc)` (i.e. the\n",
    "    mean is multiplied by `loc`) when both `loc` and `scale` are given;\n",
    "    `scale` itself is not used. Assumes `loc` > 0.\n",
    "    \"\"\"\n",
    "    log_mu = output[0]\n",
    "    if (loc is not None) and (scale is not None):\n",
    "        # Out-of-place add: `output[0]` belongs to the caller, and an in-place\n",
    "        # update would silently modify it (and can interfere with autograd).\n",
    "        log_mu = log_mu + torch.log(loc) # TODO : rho scaling\n",
    "    return (log_mu,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5931f6c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class DistributionLoss(torch.nn.Module):\n",
    "    \"\"\" DistributionLoss\n",
    "\n",
    "    This PyTorch module wraps the `torch.distribution` classes allowing it to \n",
    "    interact with NeuralForecast models modularly. It shares the negative \n",
    "    log-likelihood as the optimization objective and a sample method to \n",
    "    generate empirically the quantiles defined by the `level` list.\n",
    "\n",
    "    Additionally, it implements a distribution transformation that factorizes the\n",
    "    scale-dependent likelihood parameters into a base scale and a multiplier \n",
    "    efficiently learnable within the network's non-linearities operating ranges.\n",
    "\n",
    "    Available distributions:<br>\n",
    "    - Poisson<br>\n",
    "    - Normal<br>\n",
    "    - StudentT<br>\n",
    "    - NegativeBinomial<br>\n",
    "    - Tweedie<br>\n",
    "    - Bernoulli (Temporal Classifiers)\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `distribution`: str, identifier of a torch.distributions.Distribution class.<br>\n",
    "    `level`: float list [0,100], confidence levels for prediction intervals.<br>\n",
    "    `quantiles`: float list [0,1], alternative to level list, target quantiles.<br>\n",
    "    `num_samples`: int=1000, number of samples for the empirical quantiles.<br>\n",
    "    `return_params`: bool=False, whether or not to return the Distribution parameters.<br>\n",
    "    `**distribution_kwargs`: additional keyword arguments forwarded to the Distribution constructor (e.g. `rho` for Tweedie).<br><br>\n",
    "\n",
    "    **References:**<br>\n",
    "    - [PyTorch Probability Distributions Package: StudentT.](https://pytorch.org/docs/stable/distributions.html#studentt)<br>\n",
    "    - [David Salinas, Valentin Flunkert, Jan Gasthaus, Tim Januschowski (2020).\n",
    "       \"DeepAR: Probabilistic forecasting with autoregressive recurrent networks\". International Journal of Forecasting.](https://www.sciencedirect.com/science/article/pii/S0169207019301888)<br>\n",
    "    \"\"\"\n",
    "    def __init__(self, distribution, level=[80, 90], quantiles=None,\n",
    "                 num_samples=1000, return_params=False, **distribution_kwargs):\n",
    "        super(DistributionLoss, self).__init__()\n",
    "\n",
    "        available_distributions = dict(\n",
    "            Bernoulli=Bernoulli,\n",
    "            Normal=Normal,\n",
    "            Poisson=Poisson,\n",
    "            StudentT=StudentT,\n",
    "            NegativeBinomial=NegativeBinomial,\n",
    "            Tweedie=Tweedie)\n",
    "        domain_maps = dict(\n",
    "            Bernoulli=bernoulli_domain_map,\n",
    "            Normal=normal_domain_map,\n",
    "            Poisson=poisson_domain_map,\n",
    "            StudentT=student_domain_map,\n",
    "            NegativeBinomial=nbinomial_domain_map,\n",
    "            Tweedie=tweedie_domain_map)\n",
    "        scale_decouples = dict(\n",
    "            Bernoulli=bernoulli_scale_decouple,\n",
    "            Normal=normal_scale_decouple,\n",
    "            Poisson=poisson_scale_decouple,\n",
    "            StudentT=student_scale_decouple,\n",
    "            NegativeBinomial=nbinomial_scale_decouple,\n",
    "            Tweedie=tweedie_scale_decouple)\n",
    "        # NOTE(review): the Bernoulli and NegativeBinomial entries are named '-logits',\n",
    "        # but the corresponding scale decouples in this module produce probabilities.\n",
    "        # Names are kept as-is because they become output column names.\n",
    "        param_names = dict(\n",
    "            Bernoulli=[\"-logits\"],\n",
    "            Normal=[\"-loc\", \"-scale\"],\n",
    "            Poisson=[\"-loc\"],\n",
    "            StudentT=[\"-df\", \"-loc\", \"-scale\"],\n",
    "            NegativeBinomial=[\"-total_count\", \"-logits\"],\n",
    "            Tweedie=[\"-log_mu\"])\n",
    "        assert (distribution in available_distributions.keys()), f'{distribution} not available'\n",
    "\n",
    "        self.distribution = distribution\n",
    "        self._base_distribution = available_distributions[distribution]\n",
    "        self.domain_map = domain_maps[distribution]\n",
    "        self.scale_decouple = scale_decouples[distribution]\n",
    "        self.param_names = param_names[distribution]\n",
    "\n",
    "        self.distribution_kwargs = distribution_kwargs\n",
    "\n",
    "        qs, self.output_names = level_to_outputs(level)\n",
    "        qs = torch.Tensor(qs)\n",
    "\n",
    "        # Transform quantiles to homogeneous output names\n",
    "        if quantiles is not None:\n",
    "            _, self.output_names = quantiles_to_outputs(quantiles)\n",
    "            qs = torch.Tensor(quantiles)\n",
    "        self.quantiles = torch.nn.Parameter(qs, requires_grad=False)\n",
    "        self.num_samples = num_samples\n",
    "\n",
    "        # If True, predict_step will return Distribution's parameters\n",
    "        self.return_params = return_params\n",
    "        if self.return_params:\n",
    "            self.output_names = self.output_names + self.param_names\n",
    "\n",
    "        # Add first output entry for the sample_mean\n",
    "        self.output_names.insert(0, \"\")\n",
    "\n",
    "        self.outputsize_multiplier = len(self.param_names)\n",
    "        self.is_distribution_output = True\n",
    "\n",
    "    def get_distribution(self, distr_args, **distribution_kwargs) -> Distribution:\n",
    "        \"\"\"\n",
    "        Construct the associated PyTorch Distribution from its constructor\n",
    "        arguments. No affine transform is applied here; any loc/scale\n",
    "        anchoring presumably happens upstream in `self.scale_decouple`.\n",
    "\n",
    "        **Parameters**<br>\n",
    "        `distr_args`: Constructor arguments for the underlying Distribution type.<br>\n",
    "        `**distribution_kwargs`: extra keyword arguments for the Distribution constructor.<br>\n",
    "\n",
    "        **Returns**<br>\n",
    "        `Distribution`: instance of the selected Distribution class.<br>\n",
    "        \"\"\"\n",
    "        distr = self._base_distribution(*distr_args, **distribution_kwargs)\n",
    "\n",
    "        if self.distribution =='Poisson':\n",
    "            # Relax the default integer support so that `log_prob` argument\n",
    "            # validation accepts non-negative, non-integer targets.\n",
    "            distr.support = constraints.nonnegative\n",
    "        return distr\n",
    "\n",
    "    def sample(self,\n",
    "               distr_args: torch.Tensor,\n",
    "               num_samples: Optional[int] = None):\n",
    "        \"\"\"\n",
    "        Construct the empirical quantiles from the estimated Distribution,\n",
    "        sampling from it `num_samples` independently.\n",
    "\n",
    "        **Parameters**<br>\n",
    "        `distr_args`: Constructor arguments for the underlying Distribution type.<br>\n",
    "        `num_samples`: int, overwrite number of samples for the empirical quantiles (defaults to `self.num_samples`).<br>\n",
    "\n",
    "        **Returns**<br>\n",
    "        `samples`: tensor, shape [B,H,`num_samples`].<br>\n",
    "        `sample_mean`: tensor, shape [B,H,1], mean over the samples.<br>\n",
    "        `quantiles`: tensor, shape [B,H,Q], empirical quantiles at `self.quantiles`.<br>\n",
    "        \"\"\"\n",
    "        if num_samples is None:\n",
    "            num_samples = self.num_samples\n",
    "\n",
    "        # Distribution arguments are unpacked as 2-D [B,H] tensors.\n",
    "        B, H = distr_args[0].size()\n",
    "        Q = len(self.quantiles)\n",
    "\n",
    "        # Instantiate Scaled Decoupled Distribution\n",
    "        distr = self.get_distribution(distr_args=distr_args, **self.distribution_kwargs)\n",
    "        samples = distr.sample(sample_shape=(num_samples,))\n",
    "        samples = samples.permute(1,2,0) # [samples,B,H] -> [B,H,samples]\n",
    "        samples = samples.to(distr_args[0].device)\n",
    "        samples = samples.view(B*H, num_samples)\n",
    "        sample_mean = torch.mean(samples, dim=-1)\n",
    "\n",
    "        # Compute quantiles\n",
    "        quantiles_device = self.quantiles.to(distr_args[0].device)\n",
    "        quants = torch.quantile(input=samples, \n",
    "                                q=quantiles_device, dim=1)\n",
    "        quants = quants.permute((1,0)) # [Q, B*H] -> [B*H, Q]\n",
    "\n",
    "        # Final reshapes\n",
    "        samples = samples.view(B, H, num_samples)\n",
    "        sample_mean = sample_mean.view(B, H, 1)\n",
    "        quants  = quants.view(B, H, Q)\n",
    "\n",
    "        return samples, sample_mean, quants\n",
    "\n",
    "    def __call__(self,\n",
    "                 y: torch.Tensor,\n",
    "                 distr_args: torch.Tensor,\n",
    "                 mask: Union[torch.Tensor, None] = None):\n",
    "        \"\"\"\n",
    "        Computes the negative log-likelihood objective function. \n",
    "        To estimate the following predictive distribution:\n",
    "\n",
    "        $$\\mathrm{P}(\\mathbf{y}_{\\\\tau}\\,|\\,\\\\theta) \\\\quad \\mathrm{and} \\\\quad -\\log(\\mathrm{P}(\\mathbf{y}_{\\\\tau}\\,|\\,\\\\theta))$$\n",
    "\n",
    "        where $\\\\theta$ represents the distribution's parameters. It additionally \n",
    "        summarizes the objective signal using a weighted average using the `mask` tensor. \n",
    "\n",
    "        **Parameters**<br>\n",
    "        `y`: tensor, Actual values.<br>\n",
    "        `distr_args`: Constructor arguments for the underlying Distribution type.<br>\n",
    "        `mask`: tensor, Specifies date stamps per serie to consider in loss.<br>\n",
    "\n",
    "        **Returns**<br>\n",
    "        `loss`: scalar, weighted loss function against which backpropagation will be performed.<br>\n",
    "        \"\"\"\n",
    "        # Instantiate Scaled Decoupled Distribution\n",
    "        distr = self.get_distribution(distr_args=distr_args, **self.distribution_kwargs)\n",
    "        loss_values = -distr.log_prob(y)\n",
    "        loss_weights = mask\n",
    "        return weighted_average(loss_values, weights=loss_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a462101b",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(DistributionLoss, title_level=3, name='DistributionLoss.__init__')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8c367f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(DistributionLoss.sample, title_level=3, name='DistributionLoss.sample')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04e32679",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(DistributionLoss.__call__, title_level=3, name='DistributionLoss.__call__')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14a7e381",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | hide\n",
    "# Unit tests: the `quantiles` attribute of DistributionLoss must be built\n",
    "# correctly from either `level` or an explicit `quantiles` list.\n",
    "check = DistributionLoss(distribution='Normal', level=[80, 90])\n",
    "test_eq(len(check.quantiles), 5)\n",
    "\n",
    "# Explicit quantiles that include the median\n",
    "check = DistributionLoss(\n",
    "    distribution='Normal',\n",
    "    quantiles=[0.0100, 0.1000, 0.5, 0.9000, 0.9900],\n",
    ")\n",
    "print(check.output_names)\n",
    "print(check.quantiles)\n",
    "test_eq(len(check.quantiles), 5)\n",
    "\n",
    "# Explicit quantiles without the median keep their length\n",
    "check = DistributionLoss(\n",
    "    distribution='Normal',\n",
    "    quantiles=[0.0100, 0.1000, 0.9000, 0.9900],\n",
    ")\n",
    "test_eq(len(check.quantiles), 4)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "07f459b8",
   "metadata": {},
   "source": [
    "## Poisson Mixture Mesh (PMM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46ec688f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class PMM(torch.nn.Module):\n",
    "    \"\"\" Poisson Mixture Mesh\n",
    "\n",
    "    This Poisson Mixture statistical model assumes independence across groups of \n",
    "    data $\\mathcal{G}=\\{[g_{i}]\\}$, and estimates relationships within the group.\n",
    "\n",
    "    $$ \\mathrm{P}\\\\left(\\mathbf{y}_{[b][t+1:t+H]}\\\\right) = \n",
    "    \\prod_{ [g_{i}] \\in \\mathcal{G}} \\mathrm{P} \\\\left(\\mathbf{y}_{[g_{i}][\\\\tau]} \\\\right) =\n",
    "    \\prod_{\\\\beta\\in[g_{i}]} \n",
    "    \\\\left(\\sum_{k=1}^{K} w_k \\prod_{(\\\\beta,\\\\tau) \\in [g_i][t+1:t+H]} \\mathrm{Poisson}(y_{\\\\beta,\\\\tau}, \\hat{\\\\lambda}_{\\\\beta,\\\\tau,k}) \\\\right)$$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `n_components`: int=10, the number of mixture components.<br>\n",
    "    `level`: float list [0,100], confidence levels for prediction intervals.<br>\n",
    "    `quantiles`: float list [0,1], alternative to level list, target quantiles.<br>\n",
    "    `return_params`: bool=False, wether or not return the Distribution parameters.<br>\n",
    "    `batch_correlation`: bool=False, wether or not model batch correlations.<br>\n",
    "    `horizon_correlation`: bool=False, wether or not model horizon correlations.<br>\n",
    "\n",
    "    **References:**<br>\n",
    "    [Kin G. Olivares, O. Nganba Meetei, Ruijun Ma, Rohan Reddy, Mengfei Cao, Lee Dicker. \n",
    "    Probabilistic Hierarchical Forecasting with Deep Poisson Mixtures. Submitted to the International \n",
    "    Journal Forecasting, Working paper available at arxiv.](https://arxiv.org/pdf/2110.13179.pdf)\n",
    "    \"\"\"\n",
    "    def __init__(self, n_components=10, level=[80, 90], quantiles=None,\n",
    "                 num_samples=1000, return_params=False,\n",
    "                 batch_correlation=False, horizon_correlation=False):\n",
    "        super(PMM, self).__init__()\n",
    "        # Transform level to MQLoss parameters\n",
    "        qs, self.output_names = level_to_outputs(level)\n",
    "        qs = torch.Tensor(qs)\n",
    "\n",
    "        # Transform quantiles to homogeneus output names\n",
    "        if quantiles is not None:\n",
    "            _, self.output_names = quantiles_to_outputs(quantiles)\n",
    "            qs = torch.Tensor(quantiles)\n",
    "        self.quantiles = torch.nn.Parameter(qs, requires_grad=False)\n",
    "        self.num_samples = num_samples\n",
    "        self.batch_correlation = batch_correlation\n",
    "        self.horizon_correlation = horizon_correlation\n",
    "\n",
    "        # If True, predict_step will return Distribution's parameters\n",
    "        self.return_params = return_params\n",
    "        if self.return_params:\n",
    "            self.param_names = [f\"-lambda-{i}\" for i in range(1, n_components + 1)]\n",
    "            self.output_names = self.output_names + self.param_names\n",
    "\n",
    "        # Add first output entry for the sample_mean\n",
    "        self.output_names.insert(0, \"\")\n",
    "\n",
    "        self.outputsize_multiplier = n_components\n",
    "        self.is_distribution_output = True\n",
    "\n",
    "    def domain_map(self, output: torch.Tensor):\n",
    "        return (output,)#, weights\n",
    "        \n",
    "    def scale_decouple(self, \n",
    "                       output,\n",
    "                       loc: Optional[torch.Tensor] = None,\n",
    "                       scale: Optional[torch.Tensor] = None):\n",
    "        \"\"\" Scale Decouple\n",
    "\n",
    "        Stabilizes model's output optimization, by learning residual\n",
    "        variance and residual location based on anchoring `loc`, `scale`.\n",
    "        Also adds domain protection to the distribution parameters.\n",
    "        \"\"\"\n",
    "        lambdas = output[0]\n",
    "        if (loc is not None) and (scale is not None):\n",
    "            loc = loc.view(lambdas.size(dim=0), 1, -1)\n",
    "            scale = scale.view(lambdas.size(dim=0), 1, -1)\n",
    "            lambdas = (lambdas * scale) + loc\n",
    "        lambdas = F.softplus(lambdas)\n",
    "        return (lambdas,)\n",
    "\n",
    "    def sample(self, distr_args, num_samples=None):\n",
    "        \"\"\"\n",
    "        Construct the empirical quantiles from the estimated Distribution,\n",
    "        sampling from it `num_samples` independently.\n",
    "\n",
    "        **Parameters**<br>\n",
    "        `distr_args`: Constructor arguments for the underlying Distribution type.<br>\n",
    "        `loc`: Optional tensor, of the same shape as the batch_shape + event_shape\n",
    "               of the resulting distribution.<br>\n",
    "        `scale`: Optional tensor, of the same shape as the batch_shape+event_shape \n",
    "               of the resulting distribution.<br>\n",
    "        `num_samples`: int=500, overwrites number of samples for the empirical quantiles.<br>\n",
    "\n",
    "        **Returns**<br>\n",
    "        `samples`: tensor, shape [B,H,`num_samples`].<br>\n",
    "        `quantiles`: tensor, empirical quantiles defined by `levels`.<br>\n",
    "        \"\"\"\n",
    "        if num_samples is None:\n",
    "            num_samples = self.num_samples\n",
    "\n",
    "        lambdas = distr_args[0]\n",
    "        B, H, K = lambdas.size()\n",
    "        Q = len(self.quantiles)\n",
    "\n",
    "        # Sample K ~ Mult(weights)\n",
    "        # shared across B, H\n",
    "        # weights = torch.repeat_interleave(input=weights, repeats=H, dim=2)\n",
    "        weights = (1/K) * torch.ones_like(lambdas).to(lambdas.device)\n",
    "\n",
    "        # Avoid loop, vectorize\n",
    "        weights = weights.reshape(-1, K)\n",
    "        lambdas = lambdas.flatten()        \n",
    "\n",
    "        # Vectorization trick to recover row_idx\n",
    "        sample_idxs = torch.multinomial(input=weights, \n",
    "                                        num_samples=num_samples,\n",
    "                                        replacement=True)\n",
    "        aux_col_idx = torch.unsqueeze(torch.arange(B*H),-1) * K\n",
    "\n",
    "        # To device\n",
    "        sample_idxs = sample_idxs.to(lambdas.device)\n",
    "        aux_col_idx = aux_col_idx.to(lambdas.device)\n",
    "\n",
    "        sample_idxs = sample_idxs + aux_col_idx\n",
    "        sample_idxs = sample_idxs.flatten()\n",
    "\n",
    "        sample_lambdas = lambdas[sample_idxs]\n",
    "\n",
    "        # Sample y ~ Poisson(lambda) independently\n",
    "        samples = torch.poisson(sample_lambdas).to(lambdas.device)\n",
    "        samples = samples.view(B*H, num_samples)\n",
    "        sample_mean = torch.mean(samples, dim=-1)\n",
    "\n",
    "        # Compute quantiles\n",
    "        quantiles_device = self.quantiles.to(lambdas.device)\n",
    "        quants = torch.quantile(input=samples, q=quantiles_device, dim=1)\n",
    "        quants = quants.permute((1,0)) # Q, B*H\n",
    "\n",
    "        # Final reshapes\n",
    "        samples = samples.view(B, H, num_samples)\n",
    "        sample_mean = sample_mean.view(B, H, 1)\n",
    "        quants  = quants.view(B, H, Q)\n",
    "\n",
    "        return samples, sample_mean, quants\n",
    "    \n",
    "    def neglog_likelihood(self,\n",
    "                          y: torch.Tensor,\n",
    "                          distr_args: Tuple[torch.Tensor],\n",
    "                          mask: Union[torch.Tensor, None] = None,):\n",
    "        if mask is None: \n",
    "            mask = (y > 0) * 1\n",
    "        else:\n",
    "            mask = mask * ((y > 0) * 1)\n",
    "\n",
    "        eps  = 1e-10\n",
    "        lambdas = distr_args[0]\n",
    "        B, H, K = lambdas.size()\n",
    "\n",
    "        weights = (1/K) * torch.ones_like(lambdas).to(lambdas.device)\n",
    "\n",
    "        y = y[:,:,None]\n",
    "        mask = mask[:,:,None]\n",
    "\n",
    "        y = y * mask # Protect y negative entries\n",
    "        \n",
    "        # Single Poisson likelihood\n",
    "        log_pi = y.xlogy(lambdas + eps) - lambdas - (y + 1).lgamma()\n",
    "\n",
    "        if self.batch_correlation:\n",
    "            log_pi  = torch.sum(log_pi, dim=0, keepdim=True)\n",
    "\n",
    "        if self.horizon_correlation:\n",
    "            log_pi  = torch.sum(log_pi, dim=1, keepdim=True)\n",
    "\n",
    "        # Numerically Stable Mixture loglikelihood\n",
    "        loglik = torch.logsumexp((torch.log(weights) + log_pi), dim=2, keepdim=True)\n",
    "        loglik = loglik * mask\n",
    "\n",
    "        mean   = torch.sum(weights * lambdas, axis=-1, keepdims=True)\n",
    "        reglrz = torch.mean(torch.square(y - mean) * mask)\n",
    "        loss   = -torch.mean(loglik) + 0.001 * reglrz\n",
    "        return loss\n",
    "\n",
    "    def __call__(self, y: torch.Tensor,\n",
    "                 distr_args: Tuple[torch.Tensor],\n",
    "                 mask: Union[torch.Tensor, None] = None):\n",
    "\n",
    "        return self.neglog_likelihood(y=y, distr_args=distr_args, mask=mask)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62d7daba",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(PMM, title_level=3, name='PMM.__init__')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa8da65c",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(PMM.sample, title_level=3, name='PMM.sample')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba75717c",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(PMM.__call__, title_level=3, name='PMM.__call__')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f7518450",
   "metadata": {},
   "source": [
    "![](imgs_losses/pmm.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4a20e21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | hide\n",
    "# Unit tests: the `quantiles` attribute of PMM must be built\n",
    "# correctly from either `level` or an explicit `quantiles` list.\n",
    "check = PMM(n_components=2, level=[80, 90])\n",
    "test_eq(len(check.quantiles), 5)\n",
    "\n",
    "# Explicit quantiles that include the median\n",
    "check = PMM(\n",
    "    n_components=2,\n",
    "    quantiles=[0.0100, 0.1000, 0.5, 0.9000, 0.9900],\n",
    ")\n",
    "print(check.output_names)\n",
    "print(check.quantiles)\n",
    "test_eq(len(check.quantiles), 5)\n",
    "\n",
    "# Explicit quantiles without the median keep their length\n",
    "check = PMM(\n",
    "    n_components=2,\n",
    "    quantiles=[0.0100, 0.1000, 0.9000, 0.9900],\n",
    ")\n",
    "test_eq(len(check.quantiles), 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a56a2fbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Create a single mixture per horizon and broadcast it to N,H,K\n",
    "lambdas = torch.Tensor([[5,10,15], [10,20,30]])[None, :, :]\n",
    "\n",
    "# Create repetitions for the batch dimension N.\n",
    "N=2\n",
    "lambdas = torch.repeat_interleave(input=lambdas, repeats=N, dim=0)\n",
    "# Uniform component weights, matching the weights PMM.sample uses internally\n",
    "weights = torch.ones_like(lambdas)\n",
    "\n",
    "print('weights.shape (N,H,K) \\t', weights.shape)\n",
    "print('lambdas.shape (N,H,K) \\t', lambdas.shape)\n",
    "\n",
    "distr = PMM(quantiles=[0.1, 0.40, 0.5, 0.60, 0.9])\n",
    "distr_args = (lambdas,)\n",
    "samples, sample_mean, quants = distr.sample(distr_args)\n",
    "\n",
    "print('samples.shape (N,H,num_samples) ', samples.shape)\n",
    "print('sample_mean.shape (N,H,1) ', sample_mean.shape)\n",
    "print('quants.shape  (N,H,Q) \\t\\t', quants.shape)\n",
    "\n",
    "# Plot synthetic data for the first series\n",
    "x_plot = range(quants.shape[1]) # H length\n",
    "y_plot_hat = quants[0,:,:]  # (N,H,Q) -> (H,Q)\n",
    "samples_hat = samples[0,:,:]  # (N,H,num_samples) -> (H,num_samples)\n",
    "\n",
    "# Normalized histograms of the samples for the first two forecast horizons\n",
    "fig, ax = plt.subplots(figsize=(3.7, 2.9))\n",
    "ax.hist(samples_hat[0,:], alpha=0.5, density=True, label=r'Horizon $\\tau+1$')\n",
    "ax.hist(samples_hat[1,:], alpha=0.5, density=True, label=r'Horizon $\\tau+2$')\n",
    "ax.set(xlabel='Y values', ylabel='Density', title='Single horizon Distributions')\n",
    "ax.legend(bbox_to_anchor=(1, 1), loc='upper left', ncol=1)\n",
    "ax.grid()\n",
    "plt.show()\n",
    "plt.close(fig)\n",
    "\n",
    "# Plot simulated trajectory: median with the [q40, q60] and [q10, q90] bands\n",
    "fig, ax = plt.subplots(figsize=(3.7, 2.9))\n",
    "ax.plot(x_plot, y_plot_hat[:,2], color='black', label='median [q50]')\n",
    "ax.fill_between(x_plot,\n",
    "                y1=y_plot_hat[:,1], y2=y_plot_hat[:,3],\n",
    "                facecolor='blue', alpha=0.4, label='[p40-p60]')\n",
    "ax.fill_between(x_plot,\n",
    "                y1=y_plot_hat[:,0], y2=y_plot_hat[:,4],\n",
    "                facecolor='blue', alpha=0.2, label='[p10-p90]')\n",
    "ax.set(xlabel='Horizon', ylabel='Y values', title='PMM Probabilistic Predictions')\n",
    "ax.legend(bbox_to_anchor=(1, 1), loc='upper left', ncol=1)\n",
    "ax.grid()\n",
    "plt.show()\n",
    "plt.close(fig)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "e84e0dd4",
   "metadata": {},
   "source": [
    "## Gaussian Mixture Mesh (GMM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6928b0c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class GMM(torch.nn.Module):\n",
    "    \"\"\" Gaussian Mixture Mesh\n",
    "\n",
    "    This Gaussian Mixture statistical model assumes independence across groups of \n",
    "    data $\\mathcal{G}=\\{[g_{i}]\\}$, and estimates relationships within the group.\n",
    "\n",
    "    $$ \\mathrm{P}\\\\left(\\mathbf{y}_{[b][t+1:t+H]}\\\\right) = \n",
    "    \\prod_{ [g_{i}] \\in \\mathcal{G}} \\mathrm{P}\\left(\\mathbf{y}_{[g_{i}][\\\\tau]}\\\\right)=\n",
    "    \\prod_{\\\\beta\\in[g_{i}]}\n",
    "    \\\\left(\\sum_{k=1}^{K} w_k \\prod_{(\\\\beta,\\\\tau) \\in [g_i][t+1:t+H]} \n",
    "    \\mathrm{Gaussian}(y_{\\\\beta,\\\\tau}, \\hat{\\mu}_{\\\\beta,\\\\tau,k}, \\sigma_{\\\\beta,\\\\tau,k})\\\\right)$$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `n_components`: int=10, the number of mixture components.<br>\n",
    "    `level`: float list [0,100], confidence levels for prediction intervals.<br>\n",
    "    `quantiles`: float list [0,1], alternative to level list, target quantiles.<br>\n",
    "    `return_params`: bool=False, wether or not return the Distribution parameters.<br>\n",
    "    `batch_correlation`: bool=False, wether or not model batch correlations.<br>\n",
    "    `horizon_correlation`: bool=False, wether or not model horizon correlations.<br><br>\n",
    "\n",
    "    **References:**<br>\n",
    "    [Kin G. Olivares, O. Nganba Meetei, Ruijun Ma, Rohan Reddy, Mengfei Cao, Lee Dicker. \n",
    "    Probabilistic Hierarchical Forecasting with Deep Poisson Mixtures. Submitted to the International \n",
    "    Journal Forecasting, Working paper available at arxiv.](https://arxiv.org/pdf/2110.13179.pdf)\n",
    "    \"\"\"\n",
    "    def __init__(self, n_components=1, level=[80, 90], quantiles=None, \n",
    "                 num_samples=1000, return_params=False,\n",
    "                 batch_correlation=False, horizon_correlation=False):\n",
    "        super(GMM, self).__init__()\n",
    "        # Transform level to MQLoss parameters\n",
    "        qs, self.output_names = level_to_outputs(level)\n",
    "        qs = torch.Tensor(qs)\n",
    "\n",
    "        # Transform quantiles to homogeneus output names\n",
    "        if quantiles is not None:\n",
    "            _, self.output_names = quantiles_to_outputs(quantiles)\n",
    "            qs = torch.Tensor(quantiles)\n",
    "        self.quantiles = torch.nn.Parameter(qs, requires_grad=False)\n",
    "        self.num_samples = num_samples\n",
    "        self.batch_correlation = batch_correlation\n",
    "        self.horizon_correlation = horizon_correlation        \n",
    "\n",
    "        # If True, predict_step will return Distribution's parameters\n",
    "        self.return_params = return_params\n",
    "        if self.return_params:\n",
    "            mu_names = [f\"-mu-{i}\" for i in range(1, n_components + 1)]\n",
    "            std_names = [f\"-std-{i}\" for i in range(1, n_components + 1)]\n",
    "            mu_std_names = [i for j in zip(mu_names, std_names) for i in j]\n",
    "            self.output_names = self.output_names + mu_std_names\n",
    "\n",
    "        # Add first output entry for the sample_mean\n",
    "        self.output_names.insert(0, \"\")\n",
    "\n",
    "        self.outputsize_multiplier = 2 * n_components\n",
    "        self.is_distribution_output = True\n",
    "\n",
    "    def domain_map(self, output: torch.Tensor):\n",
    "        means, stds = torch.tensor_split(output, 2, dim=-1)\n",
    "        return (means, stds)\n",
    "\n",
    "    def scale_decouple(self, \n",
    "                       output,\n",
    "                       loc: Optional[torch.Tensor] = None,\n",
    "                       scale: Optional[torch.Tensor] = None,\n",
    "                       eps: float=0.2):\n",
    "        \"\"\" Scale Decouple\n",
    "\n",
    "        Stabilizes model's output optimization, by learning residual\n",
    "        variance and residual location based on anchoring `loc`, `scale`.\n",
    "        Also adds domain protection to the distribution parameters.\n",
    "        \"\"\"\n",
    "        means, stds = output\n",
    "        stds = F.softplus(stds)\n",
    "        if (loc is not None) and (scale is not None):\n",
    "            loc = loc.view(means.size(dim=0), 1, -1)\n",
    "            scale = scale.view(means.size(dim=0), 1, -1)            \n",
    "            means = (means * scale) + loc\n",
    "            stds = (stds + eps) * scale\n",
    "        return (means, stds)\n",
    "\n",
    "    def sample(self, distr_args, num_samples=None):\n",
    "        \"\"\"\n",
    "        Construct the empirical quantiles from the estimated Distribution,\n",
    "        sampling from it `num_samples` independently.\n",
    "\n",
    "        **Parameters**<br>\n",
    "        `distr_args`: Constructor arguments for the underlying Distribution type.<br>\n",
    "        `loc`: Optional tensor, of the same shape as the batch_shape + event_shape\n",
    "               of the resulting distribution.<br>\n",
    "        `scale`: Optional tensor, of the same shape as the batch_shape+event_shape \n",
    "               of the resulting distribution.<br>\n",
    "        `num_samples`: int=500, number of samples for the empirical quantiles.<br>\n",
    "\n",
    "        **Returns**<br>\n",
    "        `samples`: tensor, shape [B,H,`num_samples`].<br>\n",
    "        `quantiles`: tensor, empirical quantiles defined by `levels`.<br>\n",
    "        \"\"\"\n",
    "        if num_samples is None:\n",
    "            num_samples = self.num_samples\n",
    "            \n",
    "        means, stds = distr_args\n",
    "        B, H, K = means.size()\n",
    "        Q = len(self.quantiles)\n",
    "        assert means.shape == stds.shape\n",
    "\n",
    "        # Sample K ~ Mult(weights)\n",
    "        # shared across B, H\n",
    "        # weights = torch.repeat_interleave(input=weights, repeats=H, dim=2)\n",
    "        \n",
    "        weights = (1/K) * torch.ones_like(means).to(means.device)\n",
    "        \n",
    "        # Avoid loop, vectorize\n",
    "        weights = weights.reshape(-1, K)\n",
    "        means = means.flatten()\n",
    "        stds = stds.flatten()\n",
    "\n",
    "        # Vectorization trick to recover row_idx\n",
    "        sample_idxs = torch.multinomial(input=weights, \n",
    "                                        num_samples=num_samples,\n",
    "                                        replacement=True)\n",
    "        aux_col_idx = torch.unsqueeze(torch.arange(B*H),-1) * K\n",
    "\n",
    "        # To device\n",
    "        sample_idxs = sample_idxs.to(means.device)\n",
    "        aux_col_idx = aux_col_idx.to(means.device)\n",
    "\n",
    "        sample_idxs = sample_idxs + aux_col_idx\n",
    "        sample_idxs = sample_idxs.flatten()\n",
    "\n",
    "        sample_means = means[sample_idxs]\n",
    "        sample_stds  = stds[sample_idxs]\n",
    "\n",
    "        # Sample y ~ Normal(mu, std) independently\n",
    "        samples = torch.normal(sample_means, sample_stds).to(means.device)\n",
    "        samples = samples.view(B*H, num_samples)\n",
    "        sample_mean = torch.mean(samples, dim=-1)\n",
    "\n",
    "        # Compute quantiles\n",
    "        quantiles_device = self.quantiles.to(means.device)\n",
    "        quants = torch.quantile(input=samples, q=quantiles_device, dim=1)\n",
    "        quants = quants.permute((1,0)) # Q, B*H\n",
    "\n",
    "        # Final reshapes\n",
    "        samples = samples.view(B, H, num_samples)\n",
    "        sample_mean = sample_mean.view(B, H, 1)\n",
    "        quants  = quants.view(B, H, Q)\n",
    "\n",
    "        return samples, sample_mean, quants\n",
    "\n",
    "    def neglog_likelihood(self,\n",
    "                          y: torch.Tensor,\n",
    "                          distr_args: Tuple[torch.Tensor, torch.Tensor],\n",
    "                          mask: Union[torch.Tensor, None] = None):\n",
    "\n",
    "        if mask is None: \n",
    "            mask = torch.ones_like(y)\n",
    "            \n",
    "        means, stds = distr_args\n",
    "        B, H, K = means.size()\n",
    "        \n",
    "        weights = (1/K) * torch.ones_like(means).to(means.device)\n",
    "        \n",
    "        y = y[:,:, None]\n",
    "        mask = mask[:,:,None]\n",
    "        \n",
    "        var = stds ** 2\n",
    "        log_stds = torch.log(stds)\n",
    "        log_pi = - ((y - means) ** 2 / (2 * var)) - log_stds \\\n",
    "                 - math.log(math.sqrt(2 * math.pi))\n",
    "\n",
    "        if self.batch_correlation:\n",
    "            log_pi  = torch.sum(log_pi, dim=0, keepdim=True)\n",
    "\n",
    "        if self.horizon_correlation:    \n",
    "            log_pi  = torch.sum(log_pi, dim=1, keepdim=True)\n",
    "\n",
    "        # Numerically Stable Mixture loglikelihood\n",
    "        loglik = torch.logsumexp((torch.log(weights) + log_pi), dim=2, keepdim=True)\n",
    "        loglik  = loglik * mask\n",
    "\n",
    "        loss = -torch.mean(loglik)\n",
    "        return loss\n",
    "    \n",
    "    def __call__(self, y: torch.Tensor,\n",
    "                 distr_args: Tuple[torch.Tensor, torch.Tensor],\n",
    "                 mask: Union[torch.Tensor, None] = None,):\n",
    "\n",
    "        return self.neglog_likelihood(y=y, distr_args=distr_args, mask=mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec4ebf3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(GMM, title_level=3, name='GMM.__init__')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bea56d8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(GMM.sample, title_level=3, name='GMM.sample')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f16e4f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(GMM.__call__, title_level=3, name='GMM.__call__')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "aed232a4",
   "metadata": {},
   "source": [
    "![](imgs_losses/gmm.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ebe4250",
   "metadata": {},
   "outputs": [],
   "source": [
    "# | hide\n",
    "# Unit tests to check GMM's stored quantiles\n",
    "# attribute is correctly instantiated\n",
    "check = GMM(n_components=2, level=[80, 90])\n",
    "test_eq(len(check.quantiles), 5)\n",
    "\n",
    "check = GMM(n_components=2, \n",
    "            quantiles=[0.0100, 0.1000, 0.5, 0.9000, 0.9900])\n",
    "print(check.output_names)\n",
    "print(check.quantiles)\n",
    "test_eq(len(check.quantiles), 5)\n",
    "# Explicit quantiles are stored exactly as given\n",
    "assert torch.allclose(check.quantiles, torch.Tensor([0.0100, 0.1000, 0.5, 0.9000, 0.9900]))\n",
    "\n",
    "check = GMM(n_components=2,\n",
    "            quantiles=[0.0100, 0.1000, 0.9000, 0.9900])\n",
    "test_eq(len(check.quantiles), 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "684d2382",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Create a single mixture per horizon and broadcast it to N,H,K\n",
    "means   = torch.Tensor([[5,10,15], [10,20,30]])[None, :, :]\n",
    "\n",
    "# Create repetitions for the batch dimension N.\n",
    "N=2\n",
    "means = torch.repeat_interleave(input=means, repeats=N, dim=0)\n",
    "weights = torch.ones_like(means)\n",
    "stds  = torch.ones_like(means)\n",
    "\n",
    "print('weights.shape (N,H,K) \\t', weights.shape)\n",
    "print('means.shape (N,H,K) \\t', means.shape)\n",
    "print('stds.shape (N,H,K) \\t', stds.shape)\n",
    "\n",
    "distr = GMM(quantiles=[0.1, 0.40, 0.5, 0.60, 0.9])\n",
    "distr_args = (means, stds)\n",
    "samples, sample_mean, quants = distr.sample(distr_args)\n",
    "\n",
    "print('samples.shape (N,H,num_samples) ', samples.shape)\n",
    "print('sample_mean.shape (N,H,1) ', sample_mean.shape)\n",
    "print('quants.shape  (N,H,Q) \\t\\t', quants.shape)\n",
    "\n",
    "# Plot synthetic data for the first series\n",
    "x_plot = range(quants.shape[1]) # H length\n",
    "y_plot_hat = quants[0,:,:]  # (N,H,Q) -> (H,Q)\n",
    "samples_hat = samples[0,:,:]  # (N,H,num_samples) -> (H,num_samples)\n",
    "\n",
    "# Normalized histograms of the samples for the first two forecast horizons\n",
    "fig, ax = plt.subplots(figsize=(3.7, 2.9))\n",
    "\n",
    "ax.hist(samples_hat[0,:], alpha=0.5, bins=50, density=True,\n",
    "        label=r'Horizon $\\tau+1$')\n",
    "ax.hist(samples_hat[1,:], alpha=0.5, bins=50, density=True,\n",
    "        label=r'Horizon $\\tau+2$')\n",
    "ax.set(xlabel='Y values', ylabel='Density', title='Single horizon Distributions')\n",
    "ax.legend(bbox_to_anchor=(1, 1), loc='upper left', ncol=1)\n",
    "ax.grid()\n",
    "plt.show()\n",
    "plt.close(fig)\n",
    "\n",
    "# Plot simulated trajectory: median with the [q40, q60] and [q10, q90] bands\n",
    "fig, ax = plt.subplots(figsize=(3.7, 2.9))\n",
    "ax.plot(x_plot, y_plot_hat[:,2], color='black', label='median [q50]')\n",
    "ax.fill_between(x_plot,\n",
    "                y1=y_plot_hat[:,1], y2=y_plot_hat[:,3],\n",
    "                facecolor='blue', alpha=0.4, label='[p40-p60]')\n",
    "ax.fill_between(x_plot,\n",
    "                y1=y_plot_hat[:,0], y2=y_plot_hat[:,4],\n",
    "                facecolor='blue', alpha=0.2, label='[p10-p90]')\n",
    "ax.set(xlabel='Horizon', ylabel='Y values', title='GMM Probabilistic Predictions')\n",
    "ax.legend(bbox_to_anchor=(1, 1), loc='upper left', ncol=1)\n",
    "ax.grid()\n",
    "plt.show()\n",
    "plt.close(fig)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "694a2afe",
   "metadata": {},
   "source": [
    "## Negative Binomial Mixture Mesh (NBMM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cdbe5c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class NBMM(torch.nn.Module):\n",
    "    \"\"\" Negative Binomial Mixture Mesh\n",
    "\n",
    "    This N. Binomial Mixture statistical model assumes independence across groups of \n",
    "    data $\\mathcal{G}=\\{[g_{i}]\\}$, and estimates relationships within the group.\n",
    "\n",
    "    $$ \\mathrm{P}\\\\left(\\mathbf{y}_{[b][t+1:t+H]}\\\\right) = \n",
    "    \\prod_{ [g_{i}] \\in \\mathcal{G}} \\mathrm{P}\\left(\\mathbf{y}_{[g_{i}][\\\\tau]}\\\\right)=\n",
    "    \\prod_{\\\\beta\\in[g_{i}]}\n",
    "    \\\\left(\\sum_{k=1}^{K} w_k \\prod_{(\\\\beta,\\\\tau) \\in [g_i][t+1:t+H]} \n",
    "    \\mathrm{NBinomial}(y_{\\\\beta,\\\\tau}, \\hat{r}_{\\\\beta,\\\\tau,k}, \\hat{p}_{\\\\beta,\\\\tau,k})\\\\right)$$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `n_components`: int=10, the number of mixture components.<br>\n",
    "    `level`: float list [0,100], confidence levels for prediction intervals.<br>\n",
    "    `quantiles`: float list [0,1], alternative to level list, target quantiles.<br>\n",
    "    `return_params`: bool=False, wether or not return the Distribution parameters.<br><br>\n",
    "\n",
    "    **References:**<br>\n",
    "    [Kin G. Olivares, O. Nganba Meetei, Ruijun Ma, Rohan Reddy, Mengfei Cao, Lee Dicker. \n",
    "    Probabilistic Hierarchical Forecasting with Deep Poisson Mixtures. Submitted to the International \n",
    "    Journal Forecasting, Working paper available at arxiv.](https://arxiv.org/pdf/2110.13179.pdf)\n",
    "    \"\"\"\n",
    "    def __init__(self, n_components=1, level=[80, 90], quantiles=None, \n",
    "                 num_samples=1000, return_params=False):\n",
    "        super(NBMM, self).__init__()\n",
    "        # Transform level to MQLoss parameters\n",
    "        qs, self.output_names = level_to_outputs(level)\n",
    "        qs = torch.Tensor(qs)\n",
    "\n",
    "        # Transform quantiles to homogeneus output names\n",
    "        if quantiles is not None:\n",
    "            _, self.output_names = quantiles_to_outputs(quantiles)\n",
    "            qs = torch.Tensor(quantiles)\n",
    "        self.quantiles = torch.nn.Parameter(qs, requires_grad=False)\n",
    "        self.num_samples = num_samples\n",
    "\n",
    "        # If True, predict_step will return Distribution's parameters\n",
    "        self.return_params = return_params\n",
    "        if self.return_params:\n",
    "            total_count_names = [f\"-total_count-{i}\" for i in range(1, n_components + 1)]\n",
    "            probs_names = [f\"-probs-{i}\" for i in range(1, n_components + 1)]\n",
    "            param_names = [i for j in zip(total_count_names, probs_names) for i in j]\n",
    "            self.output_names = self.output_names + param_names\n",
    "\n",
    "        # Add first output entry for the sample_mean\n",
    "        self.output_names.insert(0, \"\")            \n",
    "\n",
    "        self.outputsize_multiplier = 2 * n_components\n",
    "        self.is_distribution_output = True\n",
    "\n",
    "    def domain_map(self, output: torch.Tensor):\n",
    "        mu, alpha = torch.tensor_split(output, 2, dim=-1)\n",
    "        return (mu, alpha)\n",
    "\n",
    "    def scale_decouple(self, \n",
    "                       output,\n",
    "                       loc: Optional[torch.Tensor] = None,\n",
    "                       scale: Optional[torch.Tensor] = None,\n",
    "                       eps: float=0.2):\n",
    "        \"\"\" Scale Decouple\n",
    "\n",
    "        Stabilizes model's output optimization, by learning residual\n",
    "        variance and residual location based on anchoring `loc`, `scale`.\n",
    "        Also adds domain protection to the distribution parameters.\n",
    "        \"\"\"\n",
    "        # Efficient NBinomial parametrization\n",
    "        mu, alpha = output\n",
    "        mu = F.softplus(mu) + 1e-8\n",
    "        alpha = F.softplus(alpha) + 1e-8    # alpha = 1/total_counts\n",
    "        if (loc is not None) and (scale is not None):\n",
    "            loc = loc.view(mu.size(dim=0), 1, -1)\n",
    "            mu *= loc\n",
    "            alpha /= (loc + 1.)\n",
    "\n",
    "        # mu = total_count * (probs/(1-probs))\n",
    "        # => probs = mu / (total_count + mu)\n",
    "        # => probs = mu / [total_count * (1 + mu * (1/total_count))]\n",
    "        total_count = 1.0 / alpha\n",
    "        probs = (mu * alpha / (1.0 + mu * alpha)) + 1e-8 \n",
    "        return (total_count, probs)\n",
    "\n",
    "    def sample(self, distr_args, num_samples=None):\n",
    "        \"\"\"\n",
    "        Construct the empirical quantiles from the estimated Distribution,\n",
    "        sampling from it `num_samples` independently.\n",
    "\n",
    "        **Parameters**<br>\n",
    "        `distr_args`: Constructor arguments for the underlying Distribution type.<br>\n",
    "        `loc`: Optional tensor, of the same shape as the batch_shape + event_shape\n",
    "               of the resulting distribution.<br>\n",
    "        `scale`: Optional tensor, of the same shape as the batch_shape+event_shape \n",
    "               of the resulting distribution.<br>\n",
    "        `num_samples`: int=500, number of samples for the empirical quantiles.<br>\n",
    "\n",
    "        **Returns**<br>\n",
    "        `samples`: tensor, shape [B,H,`num_samples`].<br>\n",
    "        `quantiles`: tensor, empirical quantiles defined by `levels`.<br>\n",
    "        \"\"\"\n",
    "        if num_samples is None:\n",
    "            num_samples = self.num_samples\n",
    "            \n",
    "        total_count, probs = distr_args\n",
    "        B, H, K = total_count.size()\n",
    "        Q = len(self.quantiles)\n",
    "        assert total_count.shape == probs.shape\n",
    "\n",
    "        # Sample K ~ Mult(weights)\n",
    "        # shared across B, H\n",
    "        # weights = torch.repeat_interleave(input=weights, repeats=H, dim=2)\n",
    "        \n",
    "        weights = (1/K) * torch.ones_like(probs).to(probs.device)\n",
    "        \n",
    "        # Avoid loop, vectorize\n",
    "        weights = weights.reshape(-1, K)\n",
    "        total_count = total_count.flatten()\n",
    "        probs = probs.flatten()\n",
    "\n",
    "        # Vectorization trick to recover row_idx\n",
    "        sample_idxs = torch.multinomial(input=weights, \n",
    "                                        num_samples=num_samples,\n",
    "                                        replacement=True)\n",
    "        aux_col_idx = torch.unsqueeze(torch.arange(B*H),-1) * K\n",
    "\n",
    "        # To device\n",
    "        sample_idxs = sample_idxs.to(probs.device)\n",
    "        aux_col_idx = aux_col_idx.to(probs.device)\n",
    "\n",
    "        sample_idxs = sample_idxs + aux_col_idx\n",
    "        sample_idxs = sample_idxs.flatten()\n",
    "\n",
    "        sample_total_count = total_count[sample_idxs]\n",
    "        sample_probs  = probs[sample_idxs]\n",
    "\n",
    "        # Sample y ~ NBinomial(total_count, probs) independently\n",
    "        dist = NegativeBinomial(total_count=sample_total_count, \n",
    "                                probs=sample_probs)\n",
    "        samples = dist.sample(sample_shape=(1,)).to(probs.device)[0]\n",
    "        samples = samples.view(B*H, num_samples)\n",
    "        sample_mean = torch.mean(samples, dim=-1)\n",
    "\n",
    "        # Compute quantiles\n",
    "        quantiles_device = self.quantiles.to(probs.device)\n",
    "        quants = torch.quantile(input=samples, q=quantiles_device, dim=1)\n",
    "        quants = quants.permute((1,0)) # Q, B*H\n",
    "\n",
    "        # Final reshapes\n",
    "        samples = samples.view(B, H, num_samples)\n",
    "        sample_mean = sample_mean.view(B, H, 1)\n",
    "        quants  = quants.view(B, H, Q)\n",
    "\n",
    "        return samples, sample_mean, quants\n",
    "\n",
    "    def neglog_likelihood(self,\n",
    "                          y: torch.Tensor,\n",
    "                          distr_args: Tuple[torch.Tensor, torch.Tensor],\n",
    "                          mask: Union[torch.Tensor, None] = None):\n",
    "\n",
    "        if mask is None: \n",
    "            mask = torch.ones_like(y)\n",
    "            \n",
    "        total_count, probs = distr_args\n",
    "        B, H, K = total_count.size()\n",
    "        \n",
    "        weights = (1/K) * torch.ones_like(probs).to(probs.device)\n",
    "        \n",
    "        y = y[:,:, None]\n",
    "        mask = mask[:,:,None]\n",
    "\n",
    "        log_unnormalized_prob = (total_count * torch.log(1.-probs) + y * torch.log(probs))\n",
    "        log_normalization = (-torch.lgamma(total_count + y) + torch.lgamma(1. + y) +\n",
    "                             torch.lgamma(total_count))\n",
    "        log_normalization[total_count + y == 0.] = 0.\n",
    "        log =  log_unnormalized_prob - log_normalization\n",
    "\n",
    "        #log  = torch.sum(log, dim=0, keepdim=True) # Joint within batch/group\n",
    "        #log  = torch.sum(log, dim=1, keepdim=True) # Joint within horizon\n",
    "\n",
    "        # Numerical stability mixture and loglik\n",
    "        log_max = torch.amax(log, dim=2, keepdim=True) # [1,1,K] (collapsed joints)\n",
    "        lik     = weights * torch.exp(log-log_max)     # Take max\n",
    "        loglik  = torch.log(torch.sum(lik, dim=2, keepdim=True)) + log_max # Return max\n",
    "        \n",
    "        loglik  = loglik * mask #replace with mask\n",
    "\n",
    "        loss = -torch.mean(loglik)\n",
    "        return loss\n",
    "    \n",
    "    def __call__(self, y: torch.Tensor,\n",
    "                 distr_args: Tuple[torch.Tensor, torch.Tensor],\n",
    "                 mask: Union[torch.Tensor, None] = None,):\n",
    "\n",
    "        return self.neglog_likelihood(y=y, distr_args=distr_args, mask=mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eed5e73c",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(NBMM, name='NBMM.__init__', title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41ea98ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(NBMM.sample, name='NBMM.sample', title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c7189c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(NBMM.__call__, name='NBMM.__call__', title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b67e2931",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Create single mixture and broadcast to N,H,K\n",
    "counts   = torch.Tensor([[10,20,30], [20,40,60]])[None, :, :]\n",
    "\n",
    "# # Create repetitions for the batch dimension N.\n",
    "N=2\n",
    "counts = torch.repeat_interleave(input=counts, repeats=N, dim=0)\n",
    "weights = torch.ones_like(counts)\n",
    "probs  = torch.ones_like(counts) * 0.5\n",
    "\n",
    "print('weights.shape (N,H,K) \\t', weights.shape)\n",
    "print('counts.shape (N,H,K) \\t', counts.shape)\n",
    "print('probs.shape (N,H,K) \\t', probs.shape)\n",
    "\n",
    "model = NBMM(quantiles=[0.1, 0.40, 0.5, 0.60, 0.9])\n",
    "distr_args = (counts, probs)\n",
    "samples, sample_mean, quants = model.sample(distr_args, num_samples=2000)\n",
    "\n",
    "print('samples.shape (N,H,num_samples) ', samples.shape)\n",
    "print('sample_mean.shape (N,H) ', sample_mean.shape)\n",
    "print('quants.shape  (N,H,Q) \\t\\t', quants.shape)\n",
    "\n",
    "# Plot synthethic data\n",
    "x_plot = range(quants.shape[1]) # H length\n",
    "y_plot_hat = quants[0,:,:]  # Filter N,G,T -> H,Q\n",
    "samples_hat = samples[0,:,:]  # Filter N,G,T -> H,num_samples\n",
    "\n",
    "# Kernel density plot for single forecast horizon \\tau = t+1\n",
    "fig, ax = plt.subplots(figsize=(3.7, 2.9))\n",
    "\n",
    "ax.hist(samples_hat[0,:], alpha=0.5, bins=30,\n",
    "        label=r'Horizon $\\tau+1$')\n",
    "ax.hist(samples_hat[1,:], alpha=0.5, bins=30,\n",
    "        label=r'Horizon $\\tau+2$')\n",
    "ax.set(xlabel='Y values', ylabel='Probability')\n",
    "plt.title('Single horizon Distributions')\n",
    "plt.legend(bbox_to_anchor=(1, 1), loc='upper left', ncol=1)\n",
    "plt.grid()\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "# Plot simulated trajectory\n",
    "fig, ax = plt.subplots(figsize=(3.7, 2.9))\n",
    "plt.plot(x_plot, y_plot_hat[:,2], color='black', label='median [q50]')\n",
    "plt.fill_between(x_plot,\n",
    "                 y1=y_plot_hat[:,1], y2=y_plot_hat[:,3],\n",
    "                 facecolor='blue', alpha=0.4, label='[p25-p75]')\n",
    "plt.fill_between(x_plot,\n",
    "                 y1=y_plot_hat[:,0], y2=y_plot_hat[:,4],\n",
    "                 facecolor='blue', alpha=0.2, label='[p1-p99]')\n",
    "ax.set(xlabel='Horizon', ylabel='Y values')\n",
    "plt.title('NBM Probabilistic Predictions')\n",
    "plt.legend(bbox_to_anchor=(1, 1), loc='upper left', ncol=1)\n",
    "plt.grid()\n",
    "plt.show()\n",
    "plt.close()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a6cf4850",
   "metadata": {},
   "source": [
    "# 5. Robustified Errors\n",
    "\n",
    "This type of errors from robust statistic focus on methods resistant to outliers and violations of assumptions, providing reliable estimates and inferences. Robust estimators are used to reduce the impact of outliers, offering more stable results."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7588f6d2",
   "metadata": {},
   "source": [
    "## Huber Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ae9f60c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class HuberLoss(BasePointLoss):\n",
    "    \"\"\" Huber Loss\n",
    "\n",
    "    The Huber loss, employed in robust regression, is a loss function that \n",
    "    exhibits reduced sensitivity to outliers in data when compared to the \n",
    "    squared error loss. This function is also refered as SmoothL1.\n",
    "\n",
    "    The Huber loss function is quadratic for small errors and linear for large \n",
    "    errors, with equal values and slopes of the different sections at the two \n",
    "    points where $(y_{\\\\tau}-\\hat{y}_{\\\\tau})^{2}$=$|y_{\\\\tau}-\\hat{y}_{\\\\tau}|$.\n",
    "\n",
    "    $$ L_{\\delta}(y_{\\\\tau},\\; \\hat{y}_{\\\\tau})\n",
    "    =\\\\begin{cases}{\\\\frac{1}{2}}(y_{\\\\tau}-\\hat{y}_{\\\\tau})^{2}\\;{\\\\text{for }}|y_{\\\\tau}-\\hat{y}_{\\\\tau}|\\leq \\delta \\\\\\ \n",
    "    \\\\delta \\ \\cdot \\left(|y_{\\\\tau}-\\hat{y}_{\\\\tau}|-{\\\\frac {1}{2}}\\delta \\\\right),\\;{\\\\text{otherwise.}}\\end{cases}$$\n",
    "\n",
    "    where $\\\\delta$ is a threshold parameter that determines the point at which the loss transitions from quadratic to linear,\n",
    "    and can be tuned to control the trade-off between robustness and accuracy in the predictions.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `delta`: float=1.0, Specifies the threshold at which to change between delta-scaled L1 and L2 loss.\n",
    "    `horizon_weight`: Tensor of size h, weight for each timestamp of the forecasting window. <br>\n",
    "    \n",
    "    **References:**<br>\n",
    "    [Huber Peter, J (1964). \"Robust Estimation of a Location Parameter\". Annals of Statistics](https://projecteuclid.org/journals/annals-of-mathematical-statistics/volume-35/issue-1/Robust-Estimation-of-a-Location-Parameter/10.1214/aoms/1177703732.full)\n",
    "    \"\"\"   \n",
    "    def __init__(self, delta: float=1., horizon_weight=None):\n",
    "        super(HuberLoss, self).__init__(horizon_weight=horizon_weight,\n",
    "                                  outputsize_multiplier=1,\n",
    "                                  output_names=[''])\n",
    "        self.delta = delta\n",
    "\n",
    "    def __call__(self,\n",
    "                 y: torch.Tensor,\n",
    "                 y_hat: torch.Tensor,\n",
    "                 mask: Union[torch.Tensor, None] = None):\n",
    "        \"\"\"\n",
    "        **Parameters:**<br>\n",
    "        `y`: tensor, Actual values.<br>\n",
    "        `y_hat`: tensor, Predicted values.<br>\n",
    "        `mask`: tensor, Specifies date stamps per serie to consider in loss.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `huber_loss`: tensor (single value).\n",
    "        \"\"\"\n",
    "        losses = F.huber_loss(y, y_hat, reduction='none', delta=self.delta)        \n",
    "        weights = self._compute_weights(y=y, mask=mask)\n",
    "        return _weighted_mean(losses=losses, weights=weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ccbfa88",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(HuberLoss, name='HuberLoss.__init__', title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6226178b",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(HuberLoss.__call__, name='HuberLoss.__call__', title_level=3)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "06aad81b",
   "metadata": {},
   "source": [
    "![](imgs_losses/huber_loss.png)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7f835621",
   "metadata": {},
   "source": [
    "## Tukey Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26ea3109",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class TukeyLoss(torch.nn.Module):\n",
    "    \"\"\" Tukey Loss\n",
    "\n",
    "    The Tukey loss function, also known as Tukey's biweight function, is a \n",
    "    robust statistical loss function used in robust statistics. Tukey's loss exhibits\n",
    "    quadratic behavior near the origin, like the Huber loss; however, it is even more\n",
    "    robust to outliers as the loss for large residuals remains constant instead of \n",
    "    scaling linearly.\n",
    "\n",
    "    The parameter $c$ in Tukey's loss determines the ''saturation'' point\n",
    "    of the function: Higher values of $c$ enhance sensitivity, while lower values \n",
    "    increase resistance to outliers.\n",
    "\n",
    "    $$ L_{c}(y_{\\\\tau},\\; \\hat{y}_{\\\\tau})\n",
    "    =\\\\begin{cases}{\n",
    "    \\\\frac{c^{2}}{6}} \\\\left[1-(\\\\frac{y_{\\\\tau}-\\hat{y}_{\\\\tau}}{c})^{2} \\\\right]^{3}    \\;\\\\text{for } |y_{\\\\tau}-\\hat{y}_{\\\\tau}|\\leq c \\\\\\ \n",
    "    \\\\frac{c^{2}}{6} \\qquad \\\\text{otherwise.}  \\end{cases}$$\n",
    "\n",
    "    Please note that the Tukey loss function assumes the data to be stationary or\n",
    "    normalized beforehand. If the error values are excessively large, the algorithm\n",
    "    may need help to converge during optimization. It is advisable to employ small learning rates.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `c`: float=4.685, Specifies the Tukey loss' threshold on which residuals are no longer considered.<br>\n",
    "    `normalize`: bool=True, Wether normalization is performed within Tukey loss' computation.<br>\n",
    "\n",
    "    **References:**<br>\n",
    "    [Beaton, A. E., and Tukey, J. W. (1974). \"The Fitting of Power Series, Meaning Polynomials, Illustrated on Band-Spectroscopic Data.\"](https://www.jstor.org/stable/1267936)\n",
    "    \"\"\"\n",
    "    def __init__(self, c: float=4.685, normalize: bool=True):\n",
    "        super(TukeyLoss, self).__init__()\n",
    "        self.outputsize_multiplier = 1\n",
    "        self.c = c\n",
    "        self.normalize = normalize\n",
    "        self.output_names = ['']\n",
    "        self.is_distribution_output = False\n",
    "\n",
    "    def domain_map(self, y_hat: torch.Tensor):\n",
    "        \"\"\"\n",
    "        Univariate loss operates in dimension [B,T,H]/[B,H]\n",
    "        This changes the network's output from [B,H,1]->[B,H]\n",
    "        \"\"\"\n",
    "        return y_hat.squeeze(-1)\n",
    "\n",
    "    def masked_mean(self, x, mask, dim):\n",
    "        x_nan = x.masked_fill(mask < 1, float(\"nan\"))\n",
    "        x_mean = x_nan.nanmean(dim=dim, keepdim=True)\n",
    "        x_mean = torch.nan_to_num(x_mean, nan=0.0)\n",
    "        return x_mean\n",
    "\n",
    "    def __call__(self, y: torch.Tensor, y_hat: torch.Tensor, \n",
    "                 mask: Union[torch.Tensor, None] = None):\n",
    "        \"\"\"\n",
    "        **Parameters:**<br>\n",
    "        `y`: tensor, Actual values.<br>\n",
    "        `y_hat`: tensor, Predicted values.<br>\n",
    "        `mask`: tensor, Specifies date stamps per serie to consider in loss.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `tukey_loss`: tensor (single value).\n",
    "        \"\"\"\n",
    "        if mask is None:\n",
    "            mask = torch.ones_like(y_hat)\n",
    "\n",
    "        # We normalize the Tukey loss, to satisfy 4.685 normal outlier bounds\n",
    "        if self.normalize:\n",
    "            y_mean = self.masked_mean(x=y, mask=mask, dim=-1)\n",
    "            y_std = torch.sqrt(self.masked_mean(x=(y - y_mean) ** 2, mask=mask, dim=-1)) + 1e-2\n",
    "        else:\n",
    "            y_std = 1.\n",
    "        delta_y = torch.abs(y - y_hat) / y_std\n",
    "\n",
    "        tukey_mask = torch.greater_equal(self.c * torch.ones_like(delta_y), delta_y)\n",
    "        tukey_loss = tukey_mask * mask * (1-(delta_y/(self.c))**2)**3 + (1-(tukey_mask * 1))\n",
    "        tukey_loss = (self.c**2 / 6) * torch.mean(tukey_loss)\n",
    "        return tukey_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd4653e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TukeyLoss, name='TukeyLoss.__init__', title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7686462",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(TukeyLoss.__call__, name='TukeyLoss.__call__', title_level=3)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8ae50f25",
   "metadata": {},
   "source": [
    "![](imgs_losses/tukey_loss.png)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a8a28d9c",
   "metadata": {},
   "source": [
    "## Huberized Quantile Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "549e6bdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class HuberQLoss(BasePointLoss):\n",
    "    \"\"\" Huberized Quantile Loss\n",
    "\n",
    "    The Huberized quantile loss is a modified version of the quantile loss function that\n",
    "    combines the advantages of the quantile loss and the Huber loss. It is commonly used\n",
    "    in regression tasks, especially when dealing with data that contains outliers or heavy tails.\n",
    "\n",
    "    The Huberized quantile loss between `y` and `y_hat` measure the Huber Loss in a non-symmetric way.\n",
    "    The loss pays more attention to under/over-estimation depending on the quantile parameter $q$; \n",
    "    and controls the trade-off between robustness and accuracy in the predictions with the parameter $delta$.\n",
    "\n",
    "    $$ \\mathrm{HuberQL}(\\\\mathbf{y}_{\\\\tau}, \\\\mathbf{\\hat{y}}^{(q)}_{\\\\tau}) = \n",
    "    (1-q)\\, L_{\\delta}(y_{\\\\tau},\\; \\hat{y}^{(q)}_{\\\\tau}) \\mathbb{1}\\{ \\hat{y}^{(q)}_{\\\\tau} \\geq y_{\\\\tau} \\} + \n",
    "    q\\, L_{\\delta}(y_{\\\\tau},\\; \\hat{y}^{(q)}_{\\\\tau}) \\mathbb{1}\\{ \\hat{y}^{(q)}_{\\\\tau} < y_{\\\\tau} \\} $$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `delta`: float=1.0, Specifies the threshold at which to change between delta-scaled L1 and L2 loss.<br>\n",
    "    `q`: float, between 0 and 1. The slope of the quantile loss, in the context of quantile regression, the q determines the conditional quantile level.<br>\n",
    "    `horizon_weight`: Tensor of size h, weight for each timestamp of the forecasting window. <br>\n",
    "\n",
    "    **References:**<br>\n",
    "    [Huber Peter, J (1964). \"Robust Estimation of a Location Parameter\". Annals of Statistics](https://projecteuclid.org/journals/annals-of-mathematical-statistics/volume-35/issue-1/Robust-Estimation-of-a-Location-Parameter/10.1214/aoms/1177703732.full)<br>\n",
    "    [Roger Koenker and Gilbert Bassett, Jr., \"Regression Quantiles\".](https://www.jstor.org/stable/1913643)\n",
    "    \"\"\"\n",
    "    def __init__(self, q, delta: float=1., horizon_weight=None):\n",
    "        super(HuberQLoss, self).__init__(horizon_weight=horizon_weight,\n",
    "                                           outputsize_multiplier=1,\n",
    "                                           output_names=[f'_q{q}_d{delta}'])\n",
    "        self.q = q\n",
    "        self.delta = delta\n",
    "\n",
    "    def __call__(self,\n",
    "                 y: torch.Tensor,\n",
    "                 y_hat: torch.Tensor,\n",
    "                 mask: Union[torch.Tensor, None] = None):\n",
    "        \"\"\"\n",
    "        **Parameters:**<br>\n",
    "        `y`: tensor, Actual values.<br>\n",
    "        `y_hat`: tensor, Predicted values.<br>\n",
    "        `mask`: tensor, Specifies datapoints to consider in loss.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `huber_qloss`: tensor (single value).\n",
    "        \"\"\"\n",
    "        error  = y_hat - y\n",
    "        zero_error = torch.zeros_like(error)\n",
    "        sq     = torch.maximum(-error, zero_error)\n",
    "        s1_q   = torch.maximum(error, zero_error)\n",
    "        losses = self.q * F.huber_loss(sq, zero_error, \n",
    "                                       reduction='none', delta=self.delta) + \\\n",
    "                 (1 - self.q) * F.huber_loss(s1_q, zero_error, \n",
    "                                        reduction='none', delta=self.delta)\n",
    "\n",
    "        weights = self._compute_weights(y=y, mask=mask)\n",
    "        return _weighted_mean(losses=losses, weights=weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec830ac0",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(HuberQLoss, name='HuberQLoss.__init__', title_level=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15409d3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(HuberQLoss.__call__, name='HuberQLoss.__call__', title_level=3)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a2d97f31",
   "metadata": {},
   "source": [
    "![](imgs_losses/huber_qloss.png)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "2e7e3143",
   "metadata": {},
   "source": [
    "## Huberized MQLoss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc992c47",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class HuberMQLoss(BasePointLoss):\n",
    "    \"\"\"  Huberized Multi-Quantile loss\n",
    "\n",
    "    The Huberized Multi-Quantile loss (HuberMQL) is a modified version of the multi-quantile loss function \n",
    "    that combines the advantages of the quantile loss and the Huber loss. HuberMQL is commonly used in regression \n",
    "    tasks, especially when dealing with data that contains outliers or heavy tails. The loss function pays \n",
    "    more attention to under/over-estimation depending on the quantile list $[q_{1},q_{2},\\dots]$ parameter. \n",
    "    It controls the trade-off between robustness and prediction accuracy with the parameter $\\\\delta$.\n",
    "\n",
    "    $$ \\mathrm{HuberMQL}_{\\delta}(\\\\mathbf{y}_{\\\\tau},[\\\\mathbf{\\hat{y}}^{(q_{1})}_{\\\\tau}, ... ,\\hat{y}^{(q_{n})}_{\\\\tau}]) = \n",
    "    \\\\frac{1}{n} \\\\sum_{q_{i}} \\mathrm{HuberQL}_{\\\\delta}(\\\\mathbf{y}_{\\\\tau}, \\\\mathbf{\\hat{y}}^{(q_{i})}_{\\\\tau}) $$\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `level`: int list [0,100]. Probability levels for prediction intervals (Defaults median).\n",
    "    `quantiles`: float list [0., 1.]. Alternative to level, quantiles to estimate from y distribution.\n",
    "    `delta`: float=1.0, Specifies the threshold at which to change between delta-scaled L1 and L2 loss.<br>   \n",
    "    `horizon_weight`: Tensor of size h, weight for each timestamp of the forecasting window. <br> \n",
    "\n",
    "    **References:**<br>\n",
    "    [Huber Peter, J (1964). \"Robust Estimation of a Location Parameter\". Annals of Statistics](https://projecteuclid.org/journals/annals-of-mathematical-statistics/volume-35/issue-1/Robust-Estimation-of-a-Location-Parameter/10.1214/aoms/1177703732.full)<br>\n",
    "    [Roger Koenker and Gilbert Bassett, Jr., \"Regression Quantiles\".](https://www.jstor.org/stable/1913643)\n",
    "    \"\"\"\n",
    "    def __init__(self, level=[80, 90], quantiles=None, delta: float=1.0, horizon_weight=None):\n",
    "\n",
    "        qs, output_names = level_to_outputs(level)\n",
    "        qs = torch.Tensor(qs)\n",
    "        # Transform quantiles to homogeneus output names\n",
    "        if quantiles is not None:\n",
    "            _, output_names = quantiles_to_outputs(quantiles)\n",
    "            qs = torch.Tensor(quantiles)\n",
    "\n",
    "        super(HuberMQLoss, self).__init__(horizon_weight=horizon_weight,\n",
    "                                     outputsize_multiplier=len(qs),\n",
    "                                     output_names=output_names)\n",
    "        \n",
    "        self.quantiles = torch.nn.Parameter(qs, requires_grad=False)\n",
    "        self.delta = delta\n",
    "\n",
    "    def domain_map(self, y_hat: torch.Tensor):\n",
    "        \"\"\"\n",
    "        Identity domain map [B,T,H,Q]/[B,H,Q]\n",
    "        \"\"\"\n",
    "        return y_hat\n",
    "    \n",
    "    def _compute_weights(self, y, mask):\n",
    "        \"\"\"\n",
    "        Compute final weights for each datapoint (based on all weights and all masks)\n",
    "        Set horizon_weight to a ones[H] tensor if not set.\n",
    "        If set, check that it has the same length as the horizon in x.\n",
    "        \"\"\"\n",
    "        if mask is None:\n",
    "            mask = torch.ones_like(y).to(y.device)\n",
    "        else:\n",
    "            mask = mask.unsqueeze(1) # Add Q dimension.\n",
    "\n",
    "        if self.horizon_weight is None:\n",
    "            self.horizon_weight = torch.ones(mask.shape[-1])\n",
    "        else:\n",
    "            assert mask.shape[-1] == len(self.horizon_weight), \\\n",
    "                'horizon_weight must have same length as Y'\n",
    "    \n",
    "        weights = self.horizon_weight.clone()\n",
    "        weights = torch.ones_like(mask, device=mask.device) * weights.to(mask.device)\n",
    "        return weights * mask\n",
    "\n",
    "    def __call__(self,\n",
    "                 y: torch.Tensor,\n",
    "                 y_hat: torch.Tensor,\n",
    "                 mask: Union[torch.Tensor, None] = None):\n",
    "        \"\"\"\n",
    "        **Parameters:**<br>\n",
    "        `y`: tensor, Actual values.<br>\n",
    "        `y_hat`: tensor, Predicted values, one per quantile along the last dimension.<br>\n",
    "        `mask`: tensor, Specifies date stamps per serie to consider in loss.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `hmqloss`: tensor (single value).\n",
    "        \"\"\"\n",
    "        # y gets a trailing axis so it broadcasts against every quantile in y_hat.\n",
    "        error = y_hat - y.unsqueeze(-1)\n",
    "        zero_error = torch.zeros_like(error)\n",
    "\n",
    "        # Under-prediction (y above y_hat) and over-prediction parts of the error.\n",
    "        under = torch.maximum(-error, zero_error)\n",
    "        over = torch.maximum(error, zero_error)\n",
    "\n",
    "        # Pinball-style weighting by q / (1 - q); each side is passed through the\n",
    "        # Huber penalty against a zero target with threshold self.delta.\n",
    "        under_loss = F.huber_loss(self.quantiles * under, zero_error,\n",
    "                                  reduction='none', delta=self.delta)\n",
    "        over_loss = F.huber_loss((1 - self.quantiles) * over, zero_error,\n",
    "                                 reduction='none', delta=self.delta)\n",
    "        losses = (1/len(self.quantiles)) * (under_loss + over_loss)\n",
    "\n",
    "        # Move the quantile axis (last) to position 1 so the horizon axis ends up last,\n",
    "        # as needed for horizon weighting:\n",
    "        #   BaseWindows:   [B,H,Q] -> [B,Q,H]\n",
    "        #   BaseRecurrent: [B,seq_len,H,Q] -> [B,Q,seq_len,H]\n",
    "        if y_hat.ndim in (3, 4):\n",
    "            losses = losses.movedim(-1, 1)\n",
    "\n",
    "        # `losses` is passed as `y` so the weights pick up its extra dimension.\n",
    "        # NOTE: Weights do not have Q dimension.\n",
    "        weights = self._compute_weights(y=losses, mask=mask)\n",
    "        return _weighted_mean(losses=losses, weights=weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a662632",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(HuberMQLoss, title_level=3, name='HuberMQLoss.__init__')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82f733ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(HuberMQLoss.__call__, title_level=3, name='HuberMQLoss.__call__')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "47782e38",
   "metadata": {},
   "source": [
    "![Huber Multi-Quantile Loss](imgs_losses/hmq_loss.png)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "eb99f88b",
   "metadata": {},
   "source": [
    "# 6. Others"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "013d1502",
   "metadata": {},
   "source": [
    "## Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4fda0a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class Accuracy(torch.nn.Module):\n",
    "    \"\"\" Accuracy\n",
    "\n",
    "    Computes the accuracy between categorical `y` and `y_hat`.\n",
    "    This evaluation metric is only meant for evalution, as it\n",
    "    is not differentiable.\n",
    "\n",
    "    $$ \\mathrm{Accuracy}(\\\\mathbf{y}_{\\\\tau}, \\\\mathbf{\\hat{y}}_{\\\\tau}) = \\\\frac{1}{H} \\\\sum^{t+H}_{\\\\tau=t+1} \\mathrm{1}\\{\\\\mathbf{y}_{\\\\tau}==\\\\mathbf{\\hat{y}}_{\\\\tau}\\} $$\n",
    "\n",
    "    \"\"\"\n",
    "    def __init__(self,):\n",
    "        super(Accuracy, self).__init__()\n",
    "        self.is_distribution_output = False\n",
    "\n",
    "    def domain_map(self, y_hat: torch.Tensor):\n",
    "        \"\"\"\n",
    "        Univariate loss operates in dimension [B,T,H]/[B,H]\n",
    "        This changes the network's output from [B,H,1]->[B,H]\n",
    "        \"\"\"\n",
    "        return y_hat.squeeze(-1)\n",
    "\n",
    "    def __call__(self, y: torch.Tensor, y_hat: torch.Tensor, \n",
    "                 mask: Union[torch.Tensor, None] = None):\n",
    "        \"\"\"\n",
    "        **Parameters:**<br>\n",
    "        `y`: tensor, Actual values.<br>\n",
    "        `y_hat`: tensor, Predicted values.<br>\n",
    "        `mask`: tensor, Specifies date stamps per serie to consider in loss.<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `accuracy`: tensor (single value).\n",
    "        \"\"\"\n",
    "        if mask is None:\n",
    "            mask = torch.ones_like(y_hat)\n",
    "\n",
    "        measure = (y.unsqueeze(-1) == y_hat) * mask.unsqueeze(-1)\n",
    "        accuracy = torch.mean(measure)\n",
    "        return accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eeb2d06",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(Accuracy, title_level=3, name='Accuracy.__init__')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2111646c",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(Accuracy.__call__, title_level=3, name='Accuracy.__call__')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "3742e6be",
   "metadata": {},
   "source": [
    "## Scaled Continuous Ranked Probability Score (sCRPS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d210a2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class sCRPS(torch.nn.Module):\n",
    "    \"\"\"Scaled Continuous Ranked Probability Score\n",
    "\n",
    "    Calculates a scaled variation of the CRPS, as proposed by Rangapuram (2021),\n",
    "    to measure the accuracy of predicted quantiles `y_hat` compared to the observation `y`.\n",
    "\n",
    "    This metric averages percentage weighted absolute deviations as \n",
    "    defined by the quantile losses.\n",
    "\n",
    "    $$ \\mathrm{sCRPS}(\\\\mathbf{\\hat{y}}^{(q)}_{\\\\tau}, \\mathbf{y}_{\\\\tau}) = \\\\frac{2}{N} \\sum_{i}\n",
    "    \\int^{1}_{0}\n",
    "    \\\\frac{\\mathrm{QL}(\\\\mathbf{\\hat{y}}^{(q)}_{\\\\tau}, y_{i,\\\\tau})_{q}}{\\sum_{i} | y_{i,\\\\tau} |} dq $$\n",
    "\n",
    "    where $\\\\mathbf{\\hat{y}}^{(q)}_{\\\\tau}$ is the estimated quantile, and $y_{i,\\\\tau}$\n",
    "    are the target variable realizations.\n",
    "\n",
    "    **Parameters:**<br>\n",
    "    `level`: int list [0,100]. Probability levels for prediction intervals (default [80, 90]).<br>\n",
    "    `quantiles`: float list [0., 1.]. Alternative to level, quantiles to estimate from y distribution.<br>\n",
    "\n",
    "    **References:**<br>\n",
    "    - [Gneiting, Tilmann. (2011). \\\"Quantiles as optimal point forecasts\\\". \n",
    "    International Journal of Forecasting.](https://www.sciencedirect.com/science/article/pii/S0169207010000063)<br>\n",
    "    - [Spyros Makridakis, Evangelos Spiliotis, Vassilios Assimakopoulos, Zhi Chen, Anil Gaba, Ilia Tsetlin, Robert L. Winkler. (2022). \n",
    "    \\\"The M5 uncertainty competition: Results, findings and conclusions\\\". \n",
    "    International Journal of Forecasting.](https://www.sciencedirect.com/science/article/pii/S0169207021001722)<br>\n",
    "    - [Syama Sundar Rangapuram, Lucien D Werner, Konstantinos Benidis, Pedro Mercado, Jan Gasthaus, Tim Januschowski. (2021). \n",
    "    \\\"End-to-End Learning of Coherent Probabilistic Forecasts for Hierarchical Time Series\\\". \n",
    "    Proceedings of the 38th International Conference on Machine Learning (ICML).](https://proceedings.mlr.press/v139/rangapuram21a.html)\n",
    "    \"\"\"\n",
    "    def __init__(self, level=[80, 90], quantiles=None):\n",
    "        super(sCRPS, self).__init__()\n",
    "        self.mql = MQLoss(level=level, quantiles=quantiles)\n",
    "        self.is_distribution_output = False\n",
    "    \n",
    "    def __call__(self, y: torch.Tensor, y_hat: torch.Tensor, \n",
    "                 mask: Union[torch.Tensor, None] = None):\n",
    "        \"\"\"\n",
    "        **Parameters:**<br>\n",
    "        `y`: tensor, Actual values.<br>\n",
    "        `y_hat`: tensor, Predicted values.<br>\n",
    "        `mask`: tensor, Specifies date stamps per series to consider in loss (all ones when None).<br>\n",
    "\n",
    "        **Returns:**<br>\n",
    "        `scrps`: tensor (single value).\n",
    "        \"\"\"\n",
    "        # The mask is needed below to rescale the quantile loss, so a missing mask\n",
    "        # is replaced by an all-ones mask (torch.sum(None) would raise a TypeError).\n",
    "        if mask is None:\n",
    "            mask = torch.ones_like(y)\n",
    "\n",
    "        mql = self.mql(y=y, y_hat=y_hat, mask=mask)\n",
    "        norm = torch.sum(torch.abs(y))\n",
    "        # Number of unmasked points; presumably turns the mean returned by MQLoss back into a sum.\n",
    "        unmean = torch.sum(mask)\n",
    "        scrps = 2 * mql * unmean / (norm + 1e-5)  # 1e-5 avoids division by zero for all-zero targets\n",
    "        return scrps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53770648",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(sCRPS, title_level=3, name='sCRPS.__init__')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3646250f",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_doc(sCRPS.__call__, title_level=3, name='sCRPS.__call__')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cdfa174",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Sanity checks of mask and horizon_weight handling, using MAE.\n",
    "# y is all zeros, so every 1 in y_hat is an absolute error of 1; there are 6 datapoints.\n",
    "y = torch.Tensor([[0,0,0],[0,0,0]])\n",
    "y_hat = torch.Tensor([[0,0,1],[1,0,1]])\n",
    "\n",
    "def _check_mae(y, y_hat, mask, horizon_weight, expected):\n",
    "    \"\"\"Assert that MAE with `horizon_weight` and `mask` on (`y`, `y_hat`) equals `expected`.\"\"\"\n",
    "    mae = MAE(horizon_weight=torch.Tensor(horizon_weight))\n",
    "    loss = mae(y=y, y_hat=y_hat, mask=torch.Tensor(mask))\n",
    "    # Compare with a tolerance: the loss is a float32 ratio such as 2/5.\n",
    "    assert torch.isclose(loss, torch.tensor(expected)), f'Should be {expected}, got {loss.item()}'\n",
    "\n",
    "# Complete mask and horizon_weight: 3 errors over 6 points.\n",
    "_check_mae(y, y_hat, mask=[[1,1,1],[1,1,1]], horizon_weight=[1,1,1], expected=3/6)\n",
    "\n",
    "# Incomplete mask, complete horizon_weight: 1 error and 1 point are masked.\n",
    "_check_mae(y, y_hat, mask=[[1,1,1],[0,1,1]], horizon_weight=[1,1,1], expected=2/5)\n",
    "\n",
    "# Complete mask, incomplete horizon_weight: the last step (2 errors, 2 points) is masked.\n",
    "_check_mae(y, y_hat, mask=[[1,1,1],[1,1,1]], horizon_weight=[1,1,0], expected=1/4)\n",
    "\n",
    "# Incomplete mask and incomplete horizon_weight: 2 errors and 3 points are masked.\n",
    "_check_mae(y, y_hat, mask=[[0,1,1],[1,1,1]], horizon_weight=[1,1,0], expected=1/3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
